mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-15 03:45:58 +00:00
Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 25fbd25b05 | |||
| 34e126ee4b | |||
| ec520e6427 | |||
| 0fb2a75bfb | |||
| 5d61718c80 | |||
| 474c89bac2 | |||
| f43aa78107 | |||
| 47e9570d86 | |||
| 1783da42f4 | |||
| d23eac227f | |||
| 554017a89f | |||
| 6e839342a7 | |||
| 8955b3222a | |||
| c91dacc8e2 | |||
| cad6e89a19 | |||
| 094296440f | |||
| 839fa99237 | |||
| 3475f7cdad | |||
| 83bc2fb1ae | |||
| a17d2ff8f8 | |||
| 420a886e1d | |||
| 579e416459 | |||
| c002596ab4 | |||
| a838546a2b | |||
| bbce6c0ac0 | |||
| 0d3bfe0a76 | |||
| 503eeac788 | |||
| aa015462a7 | |||
| b8f5ed360f | |||
| 76136d22b4 | |||
| dc2ababf00 | |||
| 330a2ff8c5 | |||
| 0367fe6c7a | |||
| c733d3c917 | |||
| b6fbf0d105 | |||
| f401e7baa6 | |||
| 919d8bc279 | |||
| 2d5f0787de | |||
| 5819bd8a59 | |||
| b3c2cc42cf | |||
| 167ef4512f | |||
| ba9cc5e972 | |||
| 05ae4467ae | |||
| 2b795265e7 | |||
| a57d05fe0a | |||
| ae9e8bc0bf |
@@ -0,0 +1,141 @@
|
||||
---
|
||||
name: blocking-io-guard
|
||||
description: Ensure async-path backend code that could block the asyncio event loop is protected by a teeth-verified runtime anchor in tests/blocking_io/. Use when changing backend Python under app/, packages/harness/deerflow/, or scripts/, when running a blocking-IO triage round over the whole repo, or when a reviewer/CI asks for blocking-IO coverage. Runs a deterministic scan (changed-lines or full-repo), routes each candidate, drafts/extends an anchor, and proves it fails when the blocking IO regresses.
|
||||
---
|
||||
|
||||
# Blocking-IO Guard Skill
|
||||
|
||||
Help a contributor ship backend async changes together with the runtime anchor
|
||||
that lets DeerFlow's blocking-IO CI gate actually see the new code. The dynamic
|
||||
detector only catches blocking IO on paths a test executes — this skill closes
|
||||
that gap, either for your own diff or for a repo-wide triage round.
|
||||
|
||||
Read `references/good-anchor-rules.md` before writing any anchor.
|
||||
Only read `references/sop-skeleton.md` when generalizing this SOP to another
|
||||
detector domain — it is not needed to execute the steps below.
|
||||
|
||||
## When to use
|
||||
|
||||
- Your change touches Python under `backend/app/`,
|
||||
`backend/packages/harness/deerflow/`, or `backend/scripts/` and may run on
|
||||
the async event loop (Mode A). If unsure, run Step 0 — it answers
|
||||
deterministically.
|
||||
- You are doing a maintenance triage round over the existing codebase
|
||||
(Mode B).
|
||||
|
||||
## SOP (router)
|
||||
|
||||
### Step 0 — Scope (deterministic)
|
||||
|
||||
**Mode A — your own diff** (default, pre-PR). From repo root:
|
||||
|
||||
```bash
|
||||
uv run --project backend python scripts/scan_changed_blocking_io.py --base origin/main
|
||||
```
|
||||
|
||||
Lists blocking-IO candidates your change introduces: findings on lines the
|
||||
diff added, **plus** findings that are new versus the merge base — the latter
|
||||
catches a new async caller exposing an old sync helper whose blocking line is
|
||||
not in the diff. The diff is `<base>...HEAD`, so **commit your work first** —
|
||||
uncommitted lines are not selected.
|
||||
|
||||
If the list is empty, this change introduces no blocking-IO surface *that the
|
||||
static detector can see in the changed files*. One residual blind spot
|
||||
remains: reachability is same-file only, so a new async caller of a sync
|
||||
helper **defined in another file** is invisible to both selections. If your
|
||||
diff adds an async call into a helper that lives elsewhere, check that helper
|
||||
manually (codegraph or `git grep`) before stopping.
|
||||
|
||||
**Mode B — full-repo triage round.** From repo root:
|
||||
|
||||
```bash
|
||||
make detect-blocking-io
|
||||
```
|
||||
|
||||
Prints a summary and writes the complete structured finding list to
|
||||
`.deer-flow/blocking-io-findings.json`. Work HIGH priority first; do not start
|
||||
MEDIUM until every HIGH is dispositioned (fixed, guarded, or recorded
|
||||
NO-ACTION).
|
||||
|
||||
**Batching policy (PR sizing).** One **fix unit** per PR while any HIGH
|
||||
remains: a fix unit is one root cause — usually a single HIGH, but two HIGHs
|
||||
resolved by the same one-place fix belong together. Once no HIGH remains,
|
||||
MEDIUM/LOW may be batched (about five per round, grouped by module or by
|
||||
disposition) so each PR stays reviewable. A new Blockbuster rule is never
|
||||
batched with anything — it always ships alone (see Step 5).
|
||||
|
||||
Both modes emit the same JSON shape per finding: `priority`, `location`
|
||||
(path/line/function), `blocking_call` (category/operation/symbol),
|
||||
`event_loop_exposure`, `reason`, `code`. Priority is a deterministic review
|
||||
ordering, not proof of a bug — Step 1 makes the actual call.
|
||||
|
||||
### Step 1 — Judge each candidate (router)
|
||||
|
||||
Read the code around each candidate and route it:
|
||||
|
||||
- **Already offloaded** (`asyncio.to_thread`, `run_in_executor`, async client) →
|
||||
**GUARD**: add/extend an anchor that locks the offload so a future edit cannot
|
||||
move it back onto the loop.
|
||||
- **On the loop, not offloaded** → **FIX+ANCHOR**: offload the production code
|
||||
(your fix), then add an anchor that guards it.
|
||||
- **Not actually exposed / acceptable** (rare: scanner false positive,
|
||||
startup-only code) → **NO-ACTION**: record one line of why.
|
||||
- **Cross-file caveat**: the scanner's async reachability is same-file only
|
||||
(`ASYNC_REACHABLE_SAME_FILE`). If the candidate is a *sync helper*, check for
|
||||
async callers in other files (codegraph or `git grep`) before deciding
|
||||
NO-ACTION.
|
||||
|
||||
### Step 2 — Apply the fix, then re-scan (FIX+ANCHOR only)
|
||||
|
||||
Offload the blocking call in production code, then re-run the Step 0 scan and
|
||||
confirm the candidate no longer appears. If the offloaded call sits in a
|
||||
`finally` / cleanup path, keep it best-effort and bounded (swallow-and-log,
|
||||
`asyncio.wait_for`) so a failing or hung cleanup cannot mask the primary
|
||||
exception. Match by the stable key
|
||||
**(path, function, symbol)** — line numbers shift after edits, so never
|
||||
compare by line.
|
||||
|
||||
- The finding must disappear. If it still shows, the fix did not remove the
|
||||
blocking pattern (e.g. the call is still a direct call, not offloaded) —
|
||||
go back before touching any test.
|
||||
- GUARD / NO-ACTION routes skip this step: a residual finding there is
|
||||
*expected* (the raw call still exists inside a sync helper with the offload
|
||||
at the caller, or the exposure was judged acceptable).
|
||||
|
||||
This is pattern-level feedback in seconds; it complements but never replaces
|
||||
Step 5 — only the runtime gate proves the event loop is actually protected.
|
||||
|
||||
### Step 3 — Check existing anchors
|
||||
|
||||
Look in `backend/tests/blocking_io/` for a test that drives the production async
|
||||
entry point reaching this candidate's branch.
|
||||
|
||||
- Covers this branch already → go to Step 5 (re-verify teeth).
|
||||
- Covers the entry point but not this branch (e.g. happy path covered,
|
||||
cleanup/404/409 not) → **extend** that anchor.
|
||||
- None → create one from `templates/anchor.template.py`.
|
||||
|
||||
### Step 4 — Generate / extend the anchor
|
||||
|
||||
Follow `references/good-anchor-rules.md`. Drive the *specific* branch (e.g. force
|
||||
the create failure that hits the cleanup `shutil.rmtree`). Never bypass the
|
||||
blocking surface with a test-only `asyncio.to_thread` wrapper.
|
||||
|
||||
### Step 5 — Verify teeth (mandatory; also the anchor-vs-rule discriminator)
|
||||
|
||||
1. Reintroduce the block (GUARD: temporarily revert the offload; FIX+ANCHOR: run
|
||||
against the pre-fix code).
|
||||
2. Run `cd backend && make test-blocking-io` (or target the one test). It **must
|
||||
go RED**.
|
||||
3. Restore the fix. It **must go GREEN**.
|
||||
|
||||
A real block that stays GREEN means Blockbuster has no rule for that
|
||||
primitive — that is the **RULE** route; see `references/good-anchor-rules.md`
|
||||
for the admission criteria before adding one.
|
||||
|
||||
### Step 6 — Deliver
|
||||
|
||||
Commit the anchor(s) with your change; `make test-blocking-io` green. In the PR,
|
||||
note: candidates found, each disposition, the re-scan result (Step 2), and
|
||||
the teeth evidence (red→green). Include the reason for any NO-ACTION. A new
|
||||
Blockbuster rule, if any, goes in its own commit with the evidence from Step 5.
|
||||
@@ -0,0 +1,65 @@
|
||||
# Good anchor rules + teeth (blocking-IO fill)
|
||||
|
||||
Distilled from `backend/docs/BLOCKING_IO_DETECTION.md`. An anchor lives in
|
||||
`backend/tests/blocking_io/`; the suite's conftest runs each test under the
|
||||
strict Blockbuster gate scoped to `app.*` / `deerflow.*`.
|
||||
|
||||
The examples in this file and in `templates/` are all filesystem-flavored.
|
||||
They demonstrate how to *write* the test, not what the SOP covers: the same
|
||||
rules apply to every category the detector reports (FILE_IO, HTTP,
|
||||
SUBPROCESS, SLEEP), and the acceptance criterion is always the teeth check
|
||||
below — never similarity to an example.
|
||||
|
||||
## A good anchor
|
||||
|
||||
- Calls the **real production async entry point** — not a low-level helper,
|
||||
unless that helper *is* the entry point production executes.
|
||||
- Does **not** bypass the blocking surface with a test-only
|
||||
`asyncio.to_thread` / `run_in_executor` wrapper.
|
||||
- Uses **real local filesystem** inputs when the bug shape is filesystem IO.
|
||||
- Mocks **only** the external dependency boundary (network service, third-party
|
||||
saver), never the offload being guarded.
|
||||
- Drives the **specific branch** you are protecting (error / cleanup / 404 /
|
||||
409), not just the happy path.
|
||||
|
||||
## Teeth (the acceptance test)
|
||||
|
||||
An anchor only counts if the gate actually fires when the code blocks:
|
||||
|
||||
1. Reintroduce the block (revert the offload, or run pre-fix code).
|
||||
2. `cd backend && make test-blocking-io` → the anchor **must fail** (RED).
|
||||
3. Restore the fix → the anchor **must pass** (GREEN).
|
||||
|
||||
A green-on-happy-path anchor with no proven red is fake coverage. Don't ship it.
|
||||
|
||||
## The RULE route (rare; strict admission criteria)
|
||||
|
||||
Blockbuster's built-in rules cover the common blocking primitives well. The
|
||||
two deliberate openings in this SOP are:
|
||||
|
||||
1. **Coverage opening** (the normal case): the rules already see the
|
||||
primitive — you only need an anchor so runtime detection executes the real
|
||||
business path and CI prevents regression.
|
||||
2. **Rule opening** (rare): you reintroduced a *real* block and the gate
|
||||
stayed GREEN — Blockbuster has no rule for that primitive.
|
||||
|
||||
A project rule lives in `_PROJECT_BLOCKING_RULES` inside
|
||||
`backend/tests/support/detectors/blocking_io_runtime.py` and changes detection
|
||||
for the **entire** blocking-IO suite — global blast radius. Admission criteria
|
||||
for adding one:
|
||||
|
||||
- You have the **fails-to-fail anchor** as evidence: a good anchor (per the
|
||||
rules above) that drives a genuinely blocking path and stays green. No
|
||||
evidence, no rule.
|
||||
- The primitive is a real blocking call (verified against its implementation
|
||||
or docs), not a false positive of the static detector.
|
||||
- The rule ships in its **own commit**, naming the primitive, the anchor that
|
||||
exposed the gap, and the suite-wide impact. Run the full
|
||||
`make test-blocking-io` suite after adding it — a new rule can turn other
|
||||
previously-green tests red, and each such red is either a real latent bug
|
||||
(fix it) or rule overreach (narrow the rule).
|
||||
- If you are not in a position to own that blast radius (e.g. external
|
||||
contributor), escalate to a maintainer with the evidence instead.
|
||||
|
||||
**Never add a runtime rule just because a path is untested** — that case needs
|
||||
an anchor, not a rule.
|
||||
@@ -0,0 +1,34 @@
|
||||
# SOP skeleton (generic shape — extraction seam)
|
||||
|
||||
This is the domain-agnostic shape the blocking-IO skill instantiates. It exists
|
||||
so a second detector/gate domain can reuse the flow without copying it. Do not
|
||||
add machinery for that until a second domain actually appears (YAGNI).
|
||||
|
||||
A domain provides:
|
||||
- a **static detector** that can scan a diff (or the whole tree) and emit
|
||||
located candidates,
|
||||
- a **CI gate** that fails when the bad pattern executes,
|
||||
- a **test location** for guard tests,
|
||||
- **good-test rules** for that gate,
|
||||
- a **teeth definition** (how to make the gate fire on purpose).
|
||||
|
||||
Steps:
|
||||
1. **Scope (deterministic):** intersect the diff's added lines with the
|
||||
detector's findings → candidates this change introduced/touched. (Or, in
|
||||
triage mode, take the full finding list ordered by priority.)
|
||||
2. **Judge (router):** per candidate — guard existing fix / fix + guard /
|
||||
no-action / rule (the gate cannot see the primitive).
|
||||
3. **Fix + re-scope (fixes only):** apply the fix, re-run the detector; the
|
||||
fixed candidate must vanish from the findings (match by a stable key, not
|
||||
line numbers). Pattern-level feedback in seconds — complements, never
|
||||
replaces, step 5.
|
||||
4. **Generate:** draft or extend a guard test per the good-test rules, driving
|
||||
the specific branch.
|
||||
5. **Verify teeth:** make the bad pattern happen → gate must fail; restore →
|
||||
gate must pass. A pattern that stays green while genuinely bad is the
|
||||
"rule" signal, not a coverage success.
|
||||
6. **Deliver:** commit the verified guard test; any gate-rule change ships in
|
||||
its own commit with the fails-to-fail evidence attached.
|
||||
|
||||
To add a domain: supply a new fill doc (like `good-anchor-rules.md`) + detector,
|
||||
and promote this file into a parent skill the instances point at.
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Template: a tests/blocking_io/ runtime anchor.
|
||||
|
||||
Copy into backend/tests/blocking_io/test_<area>.py and adapt. The suite's
|
||||
conftest already wraps every test here in the strict Blockbuster gate, so you do
|
||||
NOT import or activate the detector — just drive the real async entry point.
|
||||
|
||||
Teeth check before you commit (see references/good-anchor-rules.md):
|
||||
1. reintroduce the block -> `cd backend && make test-blocking-io` must FAIL
|
||||
2. restore the fix -> it must PASS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# from app.<module> import <real_async_entry_point>
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_<entry_point>_offloads_blocking_io_on_<branch>(tmp_path: Path) -> None:
|
||||
# Arrange: real inputs at the boundary the code blocks on (FS -> tmp_path;
|
||||
# HTTP/subprocess -> stub the external service). Mock ONLY the external
|
||||
# boundary, never the offload under test.
|
||||
|
||||
# Act + Assert: call the REAL production async entry point and drive the
|
||||
# specific branch you are guarding (e.g. force a failure to hit the cleanup
|
||||
# path). If the entry point performs blocking IO on the loop, the gate fails.
|
||||
# await <real_async_entry_point>(...)
|
||||
raise NotImplementedError("Replace with the real async entry point call.")
|
||||
@@ -0,0 +1,237 @@
|
||||
---
|
||||
name: deerflow-maintainer-orchestrator
|
||||
description: "Use when a DeerFlow maintainer needs comment-only GitHub issue or PR handling: resolve issue/PR scopes with gh, analyze issues, post or draft issue comments, perform PR review comments, give fix strategy, risk classification, and validation guidance. Intended for maintainers and trusted local agents, not general contributors."
|
||||
---
|
||||
|
||||
# DeerFlow Maintainer Orchestrator
|
||||
|
||||
## Core Rule
|
||||
|
||||
This is a comment-plane skill: resolve GitHub scope, inspect evidence, and prepare or post DeerFlow issue comments and PR review comments. Keep the work comment-scoped; do not turn it into coding, branch management, release work, artifact closure, or other maintainer operations.
|
||||
|
||||
When the maintainer asks to process, handle, comment on, or review a bounded set of issues or PRs, proceed without asking follow-up questions. Treat that request as authorization for one public issue comment per selected non-skipped issue and one PR review comment per selected PR with high-confidence findings. If a PR has no high-confidence findings, do not post a public comment; report that result to the maintainer only. If the maintainer explicitly asks for analysis only, return comment-ready drafts without posting.
|
||||
|
||||
The maintainer's normal interaction should be: provide scope; receive posted comment URLs, PR review URLs, clean results, skipped items, failures, or drafts. Do not offload technical analysis to the maintainer. Make the best evidence-backed recommendation in the comment itself: describe the risk, impact, likely fix, and validation path. Ask the reporter or PR author for missing evidence only when the artifact lacks enough data to diagnose.
|
||||
|
||||
Output only the maintainer run result or comment draft. Do not announce the skill name, mode, or that no code was edited unless the user asks for process details.
|
||||
|
||||
Match the dominant language of the issue or PR unless the maintainer asks for another language. Chinese issue or PR text gets Chinese output; English issue or PR text gets English output. For mixed artifacts, use the body language, not logs or code.
|
||||
|
||||
## Artifact Resolution
|
||||
|
||||
Use GitHub tooling to resolve artifact type and scope. Do not ask the maintainer to clarify when `gh` or GitHub API can determine the answer.
|
||||
|
||||
1. Default repository is `bytedance/deer-flow` unless a URL or explicit repo says otherwise.
|
||||
2. For URLs, route `/issues/<number>` to Issue Flow and `/pull/<number>` to PR Review Flow.
|
||||
3. For typed numbers, use the typed command:
|
||||
- Issue: `gh issue view <number> --repo <repo> --json number,title,url,state,body,labels,author,comments`
|
||||
- PR: `gh pr view <number> --repo <repo> --json number,title,url,state,body,author,files,comments,reviews,statusCheckRollup,baseRefName,headRefName`
|
||||
4. Normalize multiple explicit references such as `#123`, `# 123`, and bare `123` into a number list, preserving order and de-duplicating exact repeats.
|
||||
5. For untyped numbers, try `gh pr view <number> --repo <repo> --json number,url` first. If it fails, use `gh issue view <number> --repo <repo> --json number,url`. Do not ask which type it is.
|
||||
6. For issue batches, use `gh issue list`, not the mixed GitHub issues endpoint. For PR batches, use `gh pr list`.
|
||||
7. Respect maintainer-provided count or time window. There is no hard five-item cap. If the scope is broad and underspecified, choose a practical recent slice, state the slice used, prioritize newest and highest-risk items, and report any unprocessed remainder.
|
||||
8. For "recent/latest" wording without a count, use a small default recent slice. For "recent hours" wording without a number, use six hours. Do not ask.
|
||||
9. Use `gh api` when `gh issue/pr view/list` lacks required fields such as timeline events, review threads, or precise search filters.
|
||||
10. Use GitHub search only as a fallback for natural-language filters that cannot be represented by view/list/API calls. Do not use web search for artifact routing unless GitHub tooling is unavailable.
|
||||
11. If no artifact type, number, URL, count, time window, or searchable GitHub scope can be resolved, stop with a compact "scope unresolved" report. Do not ask a follow-up question.
|
||||
|
||||
Use concise repo-local references such as `#123` and `PR #123` in maintainer reports and comments. Include full GitHub URLs only for posted comment/review links returned by GitHub or when the maintainer supplied an explicit URL.
|
||||
|
||||
## Issue Flow
|
||||
|
||||
Use Issue Flow for GitHub issues, bug reports, feature requests, support questions, and issue batches.
|
||||
|
||||
Start every issue with a cheap duplicate-opinion precheck:
|
||||
|
||||
1. Fetch issue metadata, labels, author, body, and existing comments.
|
||||
2. If labels, title, or body mark the issue as RFC (`rfc`, `[RFC]`, `RFC:`, or `Request for Comments`), classify it as `rfc-no-comment`, skip deep analysis, and do not post anything public unless the maintainer explicitly overrides the RFC skip for that item.
|
||||
3. If an existing maintainer or trusted-agent issue comment already gives a materially equivalent diagnosis, modification suggestion, information request, or blocking decision, skip deep analysis and do not post anything public for that issue.
|
||||
4. Treat ordinary reporter replies, thanks, unrelated discussion, or incomplete guesses as non-blocking.
|
||||
5. Report skipped issues to the maintainer only as compact identifiers plus the skipped reason or existing comment URL when available.
|
||||
|
||||
For non-skipped issues:
|
||||
|
||||
1. Read enough context to avoid guessing: issue body, comments, screenshots, logs, reproduction details, linked artifacts, and relevant DeerFlow code/docs.
|
||||
2. Classify the surface:
|
||||
- Frontend UI
|
||||
- Backend API
|
||||
- Agents / LangGraph
|
||||
- Sandbox
|
||||
- Skills
|
||||
- MCP
|
||||
- Dependencies
|
||||
- Default behavior
|
||||
- Docs / tests / CI only
|
||||
3. Classify actionability:
|
||||
- `ready-to-fix`: bounded, evidence sufficient, validation path clear.
|
||||
- `needs-more-evidence`: repro, logs, environment, screenshots, exact expected behavior, or failing case missing.
|
||||
- `defer-or-close`: duplicate, stale, unsupported, unactionable, or out of scope.
|
||||
- `rfc-no-comment`: RFC issue; skip public comments by default.
|
||||
4. Produce a public-safe comment from the analysis, not the analysis labels:
|
||||
- Start with one natural opener that connects to the issue context. Prefer `Thanks @author.` for reporter-authored issues when it reads naturally; omit the mention for bots, maintainer-authored tracking issues, or cases where it would add noise.
|
||||
- The opener must say something specific about the next step or boundary, not a generic assessment. Do not use generic phrases such as "This is actionable", "I would treat this as", "ready to fix", or surface/actionability/risk labels.
|
||||
- Use the smallest stable template that fits:
|
||||
|
||||
```text
|
||||
Thanks @author. <one specific sentence that frames the fix, investigation, or missing evidence.>
|
||||
|
||||
Recommended solution:
|
||||
- ...
|
||||
|
||||
Validation:
|
||||
- ...
|
||||
```
|
||||
|
||||
- Add `Evidence:` only when citing concrete code, logs, reproduction details, or other proof helps the author act.
|
||||
- Add `Risk:` only when architecture, security, public API, default behavior, or compatibility impact must be called out explicitly; make the risk specific.
|
||||
- Add `Missing info:` only when the issue cannot be diagnosed without more evidence; ask for the smallest useful data.
|
||||
- Put relevant files/components inside `Evidence:` or `Recommended solution:` bullets instead of separate metadata fields.
|
||||
- Every posted issue comment should contain concrete modification guidance and validation guidance unless the only useful response is `Missing info:`.
|
||||
5. Immediately before posting, refresh comments and skip if an equivalent maintainer or trusted-agent comment appeared during analysis.
|
||||
6. Post one issue comment when posting is authorized; otherwise return the same text as `Reply draft`.
|
||||
|
||||
Do not expose private reasoning, credentials, internal-only context, or unsupported promises. Do not say a fix was made unless a separate coding workflow actually changed code.
|
||||
|
||||
## PR Review Flow
|
||||
|
||||
Use PR Review Flow for GitHub pull requests and PR batches.
|
||||
|
||||
Start every PR with a cheap duplicate-review precheck:
|
||||
|
||||
1. Fetch PR metadata, changed file list, checks summary, existing PR reviews, existing PR comments, and review threads when available.
|
||||
2. If an existing maintainer or trusted-agent review already gives materially equivalent findings or a blocking decision, skip deep review and do not post anything public for that PR.
|
||||
3. Treat author replies, thanks, unrelated discussion, or incomplete guesses as non-blocking.
|
||||
4. Report skipped PRs to the maintainer only as compact identifiers plus the existing review/comment URL when available.
|
||||
|
||||
### Diff Base Rule
|
||||
|
||||
Before reviewing a local PR branch or local diff, fetch the base repository's target branch and compare against that fresh remote-tracking ref, not a possibly stale local `main`.
|
||||
|
||||
- For fork checkouts, prefer `upstream/<base-branch>` when `upstream` points to the base repository.
|
||||
- For direct upstream checkouts, use the base remote's fetched branch, usually `origin/<base-branch>`.
|
||||
- Prefer GitHub PR base metadata for the target branch. For non-PR local diffs, use the base repository default branch. If metadata is unavailable, default to `main` only after fetching the base remote.
|
||||
- Refresh the comparison ref explicitly, for example `git fetch <base-remote> +refs/heads/<base-branch>:refs/remotes/<base-remote>/<base-branch>`, then inspect `BASE=$(git merge-base HEAD <base-remote>/<base-branch>)` and `git diff "$BASE"...HEAD`.
|
||||
- If using `FETCH_HEAD` from a single-branch fetch instead, diff against that verified `FETCH_HEAD` immediately and do not later substitute a possibly stale remote-tracking ref.
|
||||
- For uncommitted local changes, review committed branch changes against the fresh base first, then include working-tree changes separately.
|
||||
- If the base remote or base branch cannot be established, use the GitHub PR files/diff as the source of truth. If neither local nor GitHub diff can be read, return a compact failure report and do not post a review.
|
||||
|
||||
Before posting a PR review comment:
|
||||
|
||||
1. Review only the current diff against the fresh base and changed files. Do not comment on unrelated pre-existing code unless the diff makes it newly risky.
|
||||
2. Do not report low-confidence guesses. If evidence is insufficient, omit the finding.
|
||||
3. Prioritize correctness, safety, maintainability, production risk, compatibility, and missing critical tests over style.
|
||||
4. Report concrete architecture, security, public API, default-behavior, and compatibility problems as findings when the diff causes or exposes them.
|
||||
5. Check changed behavior, edge cases, error paths, state mutation, transactions, locks, cache invalidation, cleanup, security boundaries, missing tests, performance/reliability, and API compatibility.
|
||||
6. Immediately before posting, refresh reviews/comments and skip if an equivalent maintainer or trusted-agent review appeared during analysis.
|
||||
7. If there are high-confidence findings, post a PR review comment using the PR language. If there are no high-confidence findings, do not post a public PR review/comment; report `No high-confidence review findings.` to the maintainer in the run result.
|
||||
|
||||
For public PR reviews with findings, start with one short opener that fits the review context and matches the finding count. Use singular wording only for exactly one finding, for example `Thanks @author. I found one issue that should be addressed before this is ready.` Use plural wording for multiple findings, for example `Thanks @author. I found a few issues that should be addressed before this is ready.` Omit the mention for bots or when it adds noise.
|
||||
|
||||
For each finding, use:
|
||||
|
||||
```text
|
||||
[P0/P1/P2] Title
|
||||
|
||||
- Location: file and line/range
|
||||
- Problem: what can go wrong
|
||||
- Evidence: why the diff causes it
|
||||
- Suggested fix: concrete minimal fix
|
||||
- Test: what test should cover it
|
||||
```
|
||||
|
||||
Severity guide:
|
||||
|
||||
- `P0`: causes outage, data loss, security breach, or build failure.
|
||||
- `P1`: likely production bug, serious regression, broken compatibility, or high-risk security/architecture issue.
|
||||
- `P2`: correctness, maintainability, or test concern with lower risk.
|
||||
|
||||
Do not produce compliments, summaries, or general advice. For sensitive security issues, describe impact and remediation without exploit instructions.
|
||||
|
||||
## No-Question Policy
|
||||
|
||||
Do not ask the maintainer routine clarification questions. The skill should save maintainer time by turning scope into comments through a fixed workflow.
|
||||
|
||||
Stop without asking only when:
|
||||
|
||||
- no issue/PR scope can be resolved through URLs, numbers, `gh` view/list, `gh api`, or GitHub search fallback;
|
||||
- GitHub authentication, repository access, or comment posting fails;
|
||||
- the requested action is outside comment-only scope;
|
||||
- posting would require private credentials, private security details, or non-public context.
|
||||
|
||||
In these cases, return a compact failure report with the attempted command path and the smallest next action. Do not phrase it as a question unless the maintainer explicitly asked to be prompted.
|
||||
|
||||
## DeerFlow Review Heuristics
|
||||
|
||||
Treat these as high-signal areas for issue comments and PR findings:
|
||||
|
||||
- `backend/packages/harness/deerflow/` must not import `app.*`.
|
||||
- App may depend on harness; harness must stay publishable and app-agnostic.
|
||||
- Frontend thread/message behavior and Gateway/LangGraph-compatible SSE are contract surfaces.
|
||||
- Sandbox permissions, bash/file-write tools, skill installation, and remote execution are security-sensitive.
|
||||
- Default model/provider behavior, config migration, persistence schema, public API/SSE, and LangGraph thread/run lifecycle are compatibility-sensitive.
|
||||
- Runtime docs should track user-facing or developer-facing behavior changes.
|
||||
- Security-sensitive comments should provide proof and remediation, not vague assertions.
|
||||
|
||||
## Validation Guidance
|
||||
|
||||
Recommend the checks matching the touched surface:
|
||||
|
||||
| Surface | Suggested validation |
|
||||
| --- | --- |
|
||||
| Backend API / harness / agents / MCP / skills runtime | `cd backend && make lint && make test` |
|
||||
| Blocking IO or async file/network work | `cd backend && make test-blocking-io` or a focused blocking-IO regression |
|
||||
| Harness/app boundary | `cd backend && uv run pytest tests/test_harness_boundary.py` |
|
||||
| Frontend UI/core | `cd frontend && pnpm format && pnpm lint && pnpm typecheck && BETTER_AUTH_SECRET=local-dev-secret pnpm build && make test` |
|
||||
| Front/back thread or SSE contract | backend replay golden and full-stack replay render where feasible |
|
||||
| Frontend user workflow | Playwright E2E or browser proof with screenshot/DOM assertion |
|
||||
| Docker/sandbox/provisioner | focused backend tests plus Docker/provisioner smoke when feasible |
|
||||
| Docs-only | targeted markdown review |
|
||||
|
||||
## Output
|
||||
|
||||
For Issue Flow:
|
||||
|
||||
```text
|
||||
Run result:
|
||||
Posted:
|
||||
Skipped:
|
||||
Failed:
|
||||
Per issue:
|
||||
Issue:
|
||||
Surface:
|
||||
Actionability:
|
||||
Risk:
|
||||
Comment:
|
||||
Validation:
|
||||
Comment status:
|
||||
```
|
||||
|
||||
For PR Review Flow:
|
||||
|
||||
```text
|
||||
Run result:
|
||||
Reviewed:
|
||||
Skipped:
|
||||
Clean:
|
||||
Failed:
|
||||
Per PR:
|
||||
PR:
|
||||
Public review:
|
||||
Findings:
|
||||
Review status:
|
||||
```
|
||||
|
||||
For analysis-only requests, replace `Posted`/`Reviewed` with `Drafted` and include the comment/review text without posting.
|
||||
|
||||
For batches, prefer a compact maintainer-facing table after the headline counts:
|
||||
|
||||
```text
|
||||
| Artifact | Status | Public action | Notes |
|
||||
| --- | --- | --- | --- |
|
||||
| #123 | posted | comment URL | short reason |
|
||||
| PR #456 | reviewed | review URL | P1: finding title |
|
||||
| PR #789 | clean | none | No high-confidence review findings. |
|
||||
| #321 | skipped | none | existing maintainer comment |
|
||||
```
|
||||
|
||||
Omit empty categories, no-op fields, routine command output, and raw logs. Report meaningful changes, evidence, and options.
|
||||
@@ -66,3 +66,18 @@ INFOQUEST_API_KEY=your-infoquest-api-key
|
||||
# alias, or behind a different port). docker-compose already sets these.
|
||||
# DEER_FLOW_INTERNAL_GATEWAY_BASE_URL=http://localhost:8001
|
||||
# DEER_FLOW_TRUSTED_ORIGINS=http://localhost:3000,http://localhost:2026
|
||||
|
||||
# ── Claude Code / Codex CLI subscription as a model provider (optional) ───────
|
||||
# If you configure a ClaudeChatModel / Codex model provider (or an ACP agent)
|
||||
# that reuses your CLI subscription login, prefer passing a token via env over
|
||||
# bind-mounting your whole ~/.claude / ~/.codex into the container. The Gateway
|
||||
# credential loader reads these first, so no directory mount is needed.
|
||||
# CLAUDE_CODE_CREDENTIALS_PATH points at a single .credentials.json (Claude)
|
||||
# rather than the whole dir. docker-compose.cli-auth.yaml is the opt-in
|
||||
# directory-mount fallback for adapters that need the full CLI config.
|
||||
# ACP adapters often take their own env API key (e.g. ANTHROPIC_API_KEY) and
|
||||
# need no mount at all — check the adapter's docs. See SECURITY.md.
|
||||
# CLAUDE_CODE_OAUTH_TOKEN=your-claude-code-oauth-token
|
||||
# ANTHROPIC_AUTH_TOKEN=your-anthropic-auth-token
|
||||
# CLAUDE_CODE_CREDENTIALS_PATH=/path/to/.claude/.credentials.json
|
||||
# CODEX_AUTH_PATH=/path/to/codex/auth.json
|
||||
|
||||
@@ -10,7 +10,7 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
@@ -76,7 +76,8 @@ install:
|
||||
@echo "Installing frontend dependencies..."
|
||||
@cd frontend && pnpm install
|
||||
@echo "Installing pre-commit hooks..."
|
||||
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
|
||||
@uv tool install pre-commit
|
||||
@pre-commit install --overwrite
|
||||
@echo "✓ All dependencies installed"
|
||||
@echo ""
|
||||
@echo "=========================================="
|
||||
|
||||
@@ -247,6 +247,9 @@ Access: http://localhost:2026
|
||||
|
||||
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
|
||||
|
||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||
|
||||
#### Option 2: Local Development
|
||||
@@ -340,6 +343,8 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
|
||||
|
||||
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
|
||||
|
||||
DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
|
||||
|
||||
| Channel | Transport | Difficulty |
|
||||
|---------|-----------|------------|
|
||||
| Telegram | Bot API (long-polling) | Easy |
|
||||
|
||||
+43
-13
@@ -112,6 +112,14 @@ calls are resolved by function name, so duplicate helper names in one file can
|
||||
conservatively over-report async reachability. It is intentionally
|
||||
informational and is not run from CI in this round.
|
||||
|
||||
For a diff-scoped view of the same findings, `scripts/scan_changed_blocking_io.py`
|
||||
(repo root) reports findings on the added lines of `git diff <base>...HEAD`
|
||||
plus findings new versus the merge base (so a new async caller exposing an
|
||||
untouched sync helper in the same file is still reported) — used by the
|
||||
`blocking-io-guard` skill (`.agent/skills/blocking-io-guard/`) as the
|
||||
deterministic scope step before routing each candidate to a fix and/or a
|
||||
`tests/blocking_io/` runtime anchor.
|
||||
|
||||
Regression tests related to Docker/provisioner behavior:
|
||||
- `tests/test_docker_sandbox_mode_detection.py` (mode detection from `config.yaml`)
|
||||
- `tests/test_provisioner_kubeconfig.py` (kubeconfig file/directory handling)
|
||||
@@ -226,7 +234,7 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
|
||||
|
||||
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`.
|
||||
|
||||
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
|
||||
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`, `channel_connections`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
|
||||
|
||||
Configuration priority:
|
||||
1. Explicit `config_path` argument
|
||||
@@ -284,7 +292,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
||||
**Implementations**:
|
||||
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation. Active-cache and warm-pool entries are checked with the backend during acquire/reuse; definitively dead containers are dropped from all in-process maps so the thread can discover or create a fresh sandbox instead of reusing a stale client. Backend health-check failures are treated as unknown, not dead; local discovery likewise treats an unverifiable container as not adoptable and falls through to create rather than failing acquire. `get()` remains an in-memory lookup for event-loop-safe tool paths.
|
||||
|
||||
**Virtual Path System**:
|
||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||
@@ -307,6 +315,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
|
||||
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
|
||||
**Deferred MCP tools** (if `tool_search.enabled`): `SubagentExecutor._build_initial_state` assembles deferral after policy filtering via the shared `assemble_deferred_tools` (fail-closed), appends the `tool_search` tool, injects the `<available-deferred-tools>` section into the subagent's `SystemMessage`, and threads the setup to `_create_agent`, which attaches `DeferredToolFilterMiddleware` through `build_subagent_runtime_middlewares(deferred_setup=...)`. Subagents thus withhold full MCP schemas until promotion, same as the lead agent; each task run gets a fresh `ThreadState` so promotion is isolated per run
|
||||
**Checkpointer isolation**: Subagent graphs are compiled with `checkpointer=False` to avoid inheriting the parent run's checkpointer, since subagents are one-shot and never resume.
|
||||
|
||||
### Tool System (`packages/harness/deerflow/tools/`)
|
||||
|
||||
@@ -369,29 +378,32 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
|
||||
### IM Channels System (`app/channels/`)
|
||||
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||
|
||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
||||
|
||||
**Components**:
|
||||
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
||||
- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` → `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations)
|
||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Discord on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu/Telegram incremental outbound updates
|
||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||
- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `telegram.py` registers the "Working on it..." placeholder as the stream target and edits it in place via `editMessageText`; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
|
||||
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
|
||||
|
||||
**Message Flow**:
|
||||
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||
9. Outbound → channel callbacks → platform reply
|
||||
3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
|
||||
4. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||
5. Feishu/Telegram chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||
6. Slack/Discord chat: `runs.wait()` → extract final response → publish outbound
|
||||
7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||
8. Telegram streaming: the "Working on it..." placeholder message is registered as the stream target; non-final updates `editMessageText` it in place (channel-side throttle: 1s in private chats, 3s in groups due to Telegram's 20 msg/min group cap; 4096-char truncation; rate-limited updates dropped); the final update performs the last edit and splits >4096 texts into follow-up messages
|
||||
9. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||
10. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||
11. Outbound → channel callbacks → platform reply
|
||||
|
||||
**Configuration** (`config.yaml` -> `channels`):
|
||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
||||
@@ -399,6 +411,17 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
||||
|
||||
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
|
||||
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
|
||||
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
|
||||
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers.
|
||||
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
|
||||
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
|
||||
- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user.
|
||||
- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable.
|
||||
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
|
||||
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
|
||||
|
||||
|
||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||
|
||||
@@ -429,6 +452,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||
|
||||
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
|
||||
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
|
||||
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
|
||||
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
|
||||
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
|
||||
|
||||
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
||||
|
||||
**Configuration** (`config.yaml` → `memory`):
|
||||
@@ -438,6 +467,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
||||
- `model_name` - LLM for updates (null = default model)
|
||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
||||
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
|
||||
|
||||
### Reflection System (`packages/harness/deerflow/reflection/`)
|
||||
|
||||
|
||||
+1
-1
@@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern:
|
||||
Per-thread isolated execution with virtual path translation:
|
||||
|
||||
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
|
||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. `AioSandboxProvider` validates active-cache and warm-pool containers during acquire/reuse, dropping definitively dead entries so a thread can provision a fresh sandbox after an unexpected container exit while keeping `get()` as an in-memory lookup. Backend health-check failures are treated as unknown, not dead, and a container that cannot be verified during discovery is simply not adopted (acquire falls through to create instead of failing).
|
||||
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||
|
||||
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def extract_connect_code(text: str) -> str | None:
|
||||
"""Extract the one-time channel binding code from a connect command."""
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def is_known_channel_command(text: str) -> bool:
|
||||
"""Return whether text starts with a registered channel control command."""
|
||||
if not text.startswith("/"):
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.channels.message_bus import InboundMessage
|
||||
|
||||
|
||||
async def attach_connection_identity(
|
||||
inbound: InboundMessage,
|
||||
*,
|
||||
repo: Any,
|
||||
provider: str,
|
||||
workspace_id: str | None,
|
||||
fallback_without_workspace: bool = False,
|
||||
) -> InboundMessage:
|
||||
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
||||
if repo is None:
|
||||
return inbound
|
||||
|
||||
workspace_candidates: list[str | None] = []
|
||||
if workspace_id:
|
||||
workspace_candidates.append(workspace_id)
|
||||
if fallback_without_workspace:
|
||||
workspace_candidates.append(None)
|
||||
if not workspace_candidates:
|
||||
return inbound
|
||||
|
||||
for candidate in workspace_candidates:
|
||||
connection = await repo.find_connection_by_external_identity(
|
||||
provider=provider,
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=candidate,
|
||||
)
|
||||
if connection is None:
|
||||
continue
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
return inbound
|
||||
@@ -14,7 +14,8 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
|
||||
self._incoming_messages: dict[str, Any] = {}
|
||||
self._incoming_messages_lock = threading.Lock()
|
||||
self._card_repliers: dict[str, Any] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
|
||||
text[:100],
|
||||
)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
conversation_type=conversation_type,
|
||||
sender_staff_id=sender_staff_id,
|
||||
sender_nick=sender_nick,
|
||||
conversation_id=conversation_id,
|
||||
code=connect_code,
|
||||
),
|
||||
self._main_loop,
|
||||
)
|
||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||
else:
|
||||
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
if _is_dingtalk_command(text):
|
||||
msg_type = InboundMessageType.COMMAND
|
||||
else:
|
||||
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
|
||||
return ""
|
||||
|
||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
# Running reply must finish before publish_inbound so AI card tracks are
|
||||
# registered before the manager emits streaming outbounds.
|
||||
await self._send_running_reply(chat_id, inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
@staticmethod
|
||||
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
|
||||
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
|
||||
return conversation_id
|
||||
return None
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
|
||||
conversation_id = str(inbound.metadata.get("conversation_id") or "")
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="dingtalk",
|
||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(
|
||||
self,
|
||||
*,
|
||||
conversation_type: str,
|
||||
sender_staff_id: str,
|
||||
sender_nick: str,
|
||||
conversation_id: str,
|
||||
code: str,
|
||||
) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connection code is invalid or expired.",
|
||||
)
|
||||
return True
|
||||
|
||||
if not sender_staff_id:
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connection could not be completed from this message.",
|
||||
)
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="dingtalk",
|
||||
external_account_id=sender_staff_id,
|
||||
external_account_name=sender_nick or None,
|
||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||
metadata={
|
||||
"conversation_type": conversation_type,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connected to DeerFlow.",
|
||||
)
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(
|
||||
self,
|
||||
conversation_type: str,
|
||||
sender_staff_id: str,
|
||||
conversation_id: str,
|
||||
text: str,
|
||||
) -> None:
|
||||
robot_code = self._client_id
|
||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
||||
if conversation_id:
|
||||
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
||||
return
|
||||
if sender_staff_id:
|
||||
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
||||
|
||||
@@ -10,8 +10,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
|
||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._discord_module = None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -203,10 +205,14 @@ class DiscordChannel(Channel):
|
||||
return False
|
||||
|
||||
try:
|
||||
fp = open(str(attachment.actual_path), "rb") # noqa: SIM115
|
||||
file = self._discord_module.File(fp, filename=attachment.filename)
|
||||
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
|
||||
await asyncio.wrap_future(send_future)
|
||||
# Keep the file handle open only for the duration of the upload: discord.py
|
||||
# reads ``fp`` while ``target.send`` runs on ``_discord_loop``; once that
|
||||
# future resolves the bytes are consumed, so closing here is safe and avoids
|
||||
# leaking the handle on both the success and failure paths.
|
||||
with open(str(attachment.actual_path), "rb") as fp:
|
||||
file = self._discord_module.File(fp, filename=attachment.filename)
|
||||
send_future = asyncio.run_coroutine_threadsafe(target.send(file=file), self._discord_loop)
|
||||
await asyncio.wrap_future(send_future)
|
||||
logger.info("[Discord] file uploaded: %s", attachment.filename)
|
||||
return True
|
||||
except Exception:
|
||||
@@ -287,6 +293,10 @@ class DiscordChannel(Channel):
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||
return
|
||||
|
||||
# --- Determine thread/channel routing and typing target ---
|
||||
thread_id = None
|
||||
chat_id = None
|
||||
@@ -315,6 +325,7 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
self._publish(inbound)
|
||||
# Start typing indicator in the thread
|
||||
if typing_target:
|
||||
@@ -422,6 +433,7 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
|
||||
# Start typing indicator in the correct target (thread or channel)
|
||||
if typing_target:
|
||||
@@ -436,6 +448,60 @@ class DiscordChannel(Channel):
|
||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="discord",
|
||||
workspace_id=guild_id,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
guild = getattr(message, "guild", None)
|
||||
channel = getattr(message, "channel", None)
|
||||
author = getattr(message, "author", None)
|
||||
user_id = str(getattr(author, "id", "") or "")
|
||||
if not user_id:
|
||||
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
guild_id = str(getattr(guild, "id", "") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="discord",
|
||||
external_account_id=user_id,
|
||||
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
|
||||
workspace_id=guild_id,
|
||||
workspace_name=getattr(guild, "name", None) if guild is not None else None,
|
||||
metadata={
|
||||
"guild_id": guild_id,
|
||||
"channel_id": str(getattr(channel, "id", "") or ""),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _send_connection_reply(message, text: str) -> None:
|
||||
channel = getattr(message, "channel", None)
|
||||
send = getattr(channel, "send", None)
|
||||
if send is None:
|
||||
return
|
||||
try:
|
||||
await send(text)
|
||||
except Exception:
|
||||
logger.exception("[Discord] failed to send connection reply")
|
||||
|
||||
def _run_client(self) -> None:
|
||||
self._discord_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._discord_loop)
|
||||
|
||||
@@ -11,7 +11,8 @@ import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
PENDING_CLARIFICATION_METADATA_KEY,
|
||||
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
||||
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
|
||||
self._CreateImageRequestBody = None
|
||||
self._GetMessageResourceRequest = None
|
||||
self._thread_lock = threading.Lock()
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@staticmethod
|
||||
def _non_empty_str(value: Any) -> str | None:
|
||||
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
|
||||
def supports_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
if not self._running:
|
||||
return False
|
||||
return self._thread is not None and self._thread.is_alive()
|
||||
|
||||
def _build_event_handler(self, lark):
|
||||
return (
|
||||
lark.EventDispatcherHandler.builder("", "")
|
||||
.register_p2_im_message_receive_v1(self._on_message)
|
||||
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
|
||||
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
|
||||
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
|
||||
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
|
||||
.build()
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
return
|
||||
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
|
||||
# thread's uvloop.
|
||||
_ws_client_mod.loop = loop
|
||||
|
||||
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
|
||||
event_handler = self._build_event_handler(lark)
|
||||
ws_client = lark.ws.Client(
|
||||
app_id=app_id,
|
||||
app_secret=app_secret,
|
||||
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
|
||||
except Exception:
|
||||
if self._running:
|
||||
logger.exception("Feishu WebSocket error")
|
||||
self._running = False
|
||||
|
||||
def _on_ignored_message_event(self, event) -> None:
|
||||
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
|
||||
|
||||
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
||||
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
||||
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
||||
self._ensure_running_card_started(msg_id)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="feishu",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
|
||||
if state is None:
|
||||
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id or not chat_id:
|
||||
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="feishu",
|
||||
external_account_id=user_id,
|
||||
workspace_id=chat_id,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
def _on_message(self, event) -> None:
|
||||
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
||||
try:
|
||||
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
|
||||
logger.info("[Feishu] empty text, ignoring message")
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
message_id=msg_id,
|
||||
chat_id=chat_id,
|
||||
user_id=sender_id,
|
||||
code=connect_code,
|
||||
),
|
||||
self._main_loop,
|
||||
)
|
||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||
else:
|
||||
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
# Only treat known slash commands as commands; absolute paths and
|
||||
# other slash-prefixed text should be handled as normal chat.
|
||||
if _is_feishu_command(text):
|
||||
|
||||
+168
-35
@@ -49,6 +49,11 @@ DEFAULT_RUN_CONTEXT: dict[str, Any] = {
|
||||
"subagent_enabled": False,
|
||||
}
|
||||
STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
||||
# Stream modes requested from the runtime, and the SSE event names under which
|
||||
# the message-tuple stream may arrive: the embedded runtime (and LangGraph
|
||||
# Platform) deliver the requested "messages-tuple" mode as event "messages".
|
||||
STREAM_MODES = ["messages-tuple", "values"]
|
||||
MESSAGE_STREAM_EVENTS = ("messages-tuple", "messages")
|
||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||
|
||||
CHANNEL_CAPABILITIES = {
|
||||
@@ -56,7 +61,7 @@ CHANNEL_CAPABILITIES = {
|
||||
"discord": {"supports_streaming": False},
|
||||
"feishu": {"supports_streaming": True},
|
||||
"slack": {"supports_streaming": False},
|
||||
"telegram": {"supports_streaming": False},
|
||||
"telegram": {"supports_streaming": True},
|
||||
"wechat": {"supports_streaming": False},
|
||||
"wecom": {"supports_streaming": True},
|
||||
}
|
||||
@@ -274,6 +279,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
|
||||
return metadata
|
||||
|
||||
|
||||
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
|
||||
channel_source: dict[str, Any] = {
|
||||
"type": "im_channel",
|
||||
"provider": msg.channel_name,
|
||||
"chat_id": msg.chat_id,
|
||||
}
|
||||
if msg.topic_id:
|
||||
channel_source["topic_id"] = msg.topic_id
|
||||
if msg.thread_ts:
|
||||
channel_source["thread_ts"] = msg.thread_ts
|
||||
if msg.connection_id:
|
||||
channel_source["connection_id"] = msg.connection_id
|
||||
|
||||
return {"channel_source": channel_source}
|
||||
|
||||
|
||||
def _extract_text_content(content: Any) -> str:
|
||||
"""Extract text from a streaming payload content field."""
|
||||
if isinstance(content, str):
|
||||
@@ -440,6 +461,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
|
||||
return message
|
||||
|
||||
|
||||
def _auth_disabled_owner_user_id() -> str | None:
|
||||
try:
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||
except Exception:
|
||||
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
||||
return None
|
||||
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
||||
|
||||
|
||||
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
||||
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
||||
|
||||
|
||||
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if owner_user_id:
|
||||
msg.owner_user_id = owner_user_id
|
||||
return msg
|
||||
|
||||
|
||||
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if not owner_user_id:
|
||||
return None
|
||||
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
||||
|
||||
|
||||
def _safe_user_id_for_run(raw_user_id: str) -> str:
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
try:
|
||||
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to prepare channel run user directory")
|
||||
return make_safe_user_id(raw_user_id)
|
||||
|
||||
|
||||
def _resolve_slash_skill_command(
|
||||
text: str,
|
||||
available_skills: set[str] | None = None,
|
||||
@@ -555,8 +613,14 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
write_upload_file_no_symlink,
|
||||
)
|
||||
|
||||
uploads_dir = ensure_uploads_dir(thread_id)
|
||||
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||
def _prepare_uploads_dir() -> tuple[Path, set[str]]:
|
||||
# Worker thread: ensure_uploads_dir's mkdir and the iterdir enumeration are
|
||||
# blocking filesystem IO that must stay off the event loop.
|
||||
target = ensure_uploads_dir(thread_id)
|
||||
existing = {entry.name for entry in target.iterdir() if entry.is_file()}
|
||||
return target, existing
|
||||
|
||||
uploads_dir, seen_names = await asyncio.to_thread(_prepare_uploads_dir)
|
||||
|
||||
created: list[dict[str, Any]] = []
|
||||
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
|
||||
@@ -604,7 +668,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
||||
|
||||
dest = uploads_dir / safe_name
|
||||
try:
|
||||
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
||||
dest = await asyncio.to_thread(write_upload_file_no_symlink, uploads_dir, safe_name, data)
|
||||
except UnsafeUploadPathError:
|
||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
||||
continue
|
||||
@@ -670,6 +734,7 @@ class ChannelManager:
|
||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||
default_session: dict[str, Any] | None = None,
|
||||
channel_sessions: dict[str, Any] | None = None,
|
||||
connection_repo: Any | None = None,
|
||||
) -> None:
|
||||
self.bus = bus
|
||||
self.store = store
|
||||
@@ -679,7 +744,9 @@ class ChannelManager:
|
||||
self._assistant_id = assistant_id
|
||||
self._default_session = _as_dict(default_session)
|
||||
self._channel_sessions = dict(channel_sessions or {})
|
||||
self._connection_repo = connection_repo
|
||||
self._client = None # lazy init — langgraph_sdk async client
|
||||
self._channel_metadata_synced: set[str] = set()
|
||||
self._skill_storage: SkillStorage | None = None
|
||||
self._csrf_token = generate_csrf_token()
|
||||
self._semaphore: asyncio.Semaphore | None = None
|
||||
@@ -728,12 +795,17 @@ class ChannelManager:
|
||||
configurable["checkpoint_ns"] = ""
|
||||
configurable["thread_id"] = thread_id
|
||||
|
||||
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
||||
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
||||
# under ``channel_user_id`` for platform-facing lookups.
|
||||
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
||||
# For browser-connected IM channels, prefer the DeerFlow account that
|
||||
# owns the connection. Preserve the raw platform user under
|
||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if owner_user_id:
|
||||
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
|
||||
elif msg.user_id:
|
||||
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
|
||||
if msg.user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||
run_context_identity["channel_user_id"] = msg.user_id
|
||||
|
||||
run_context = _merge_dicts(
|
||||
@@ -845,6 +917,7 @@ class ChannelManager:
|
||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||
|
||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||
msg = _apply_effective_owner(msg)
|
||||
async with self._semaphore:
|
||||
try:
|
||||
if msg.msg_type == InboundMessageType.COMMAND:
|
||||
@@ -877,10 +950,27 @@ class ChannelManager:
|
||||
|
||||
# -- chat handling -----------------------------------------------------
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
thread = await client.threads.create()
|
||||
thread_id = thread["thread_id"]
|
||||
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
return await self._connection_repo.get_thread_id(
|
||||
msg.connection_id,
|
||||
msg.chat_id,
|
||||
msg.topic_id,
|
||||
)
|
||||
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
|
||||
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
||||
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
||||
await self._connection_repo.set_thread_id(
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
provider=msg.channel_name,
|
||||
external_conversation_id=msg.chat_id,
|
||||
external_topic_id=msg.topic_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return
|
||||
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
@@ -888,18 +978,49 @@ class ChannelManager:
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
metadata = _thread_channel_metadata(msg)
|
||||
owner_headers = _owner_headers(msg)
|
||||
if owner_headers:
|
||||
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
|
||||
else:
|
||||
thread = await client.threads.create(metadata=metadata)
|
||||
thread_id = thread["thread_id"]
|
||||
await self._store_thread_id(msg, thread_id)
|
||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||
return thread_id
|
||||
|
||||
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
|
||||
"""Best-effort source metadata backfill for existing IM-created threads."""
|
||||
# The metadata (provider/chat/topic) is constant for a thread, so one
|
||||
# successful backfill per manager lifetime is enough — skip the
|
||||
# redundant PATCH on every subsequent inbound message.
|
||||
if thread_id in self._channel_metadata_synced:
|
||||
return
|
||||
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
|
||||
if owner_headers := _owner_headers(msg):
|
||||
update_kwargs["headers"] = owner_headers
|
||||
try:
|
||||
await client.threads.update(thread_id, **update_kwargs)
|
||||
except Exception:
|
||||
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
|
||||
return
|
||||
if len(self._channel_metadata_synced) > 4096:
|
||||
self._channel_metadata_synced.clear()
|
||||
self._channel_metadata_synced.add(thread_id)
|
||||
|
||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||
client = self._get_client()
|
||||
|
||||
# Look up existing DeerFlow thread.
|
||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
if thread_id:
|
||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||
await self._update_thread_channel_metadata(client, msg, thread_id)
|
||||
|
||||
# No existing thread found — create a new one
|
||||
if thread_id is None:
|
||||
@@ -940,14 +1061,19 @@ class ChannelManager:
|
||||
return
|
||||
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
run_kwargs: dict[str, Any] = {
|
||||
"input": {"messages": [human_message]},
|
||||
"config": run_config,
|
||||
"context": run_context,
|
||||
"multitask_strategy": "reject",
|
||||
}
|
||||
if owner_headers := _owner_headers(msg):
|
||||
run_kwargs["headers"] = owner_headers
|
||||
try:
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [human_message]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
multitask_strategy="reject",
|
||||
**run_kwargs,
|
||||
)
|
||||
except Exception as exc:
|
||||
if _is_thread_busy_error(exc):
|
||||
@@ -984,6 +1110,8 @@ class ChannelManager:
|
||||
artifacts=artifacts,
|
||||
attachments=attachments,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||
@@ -1008,21 +1136,26 @@ class ChannelManager:
|
||||
last_published_text = ""
|
||||
last_publish_at = 0.0
|
||||
stream_error: BaseException | None = None
|
||||
stream_kwargs: dict[str, Any] = {
|
||||
"input": {"messages": [human_message]},
|
||||
"config": run_config,
|
||||
"context": run_context,
|
||||
"stream_mode": list(STREAM_MODES),
|
||||
"multitask_strategy": "reject",
|
||||
}
|
||||
if owner_headers := _owner_headers(msg):
|
||||
stream_kwargs["headers"] = owner_headers
|
||||
|
||||
try:
|
||||
async for chunk in client.runs.stream(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [human_message]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
stream_mode=["messages-tuple", "values"],
|
||||
multitask_strategy="reject",
|
||||
**stream_kwargs,
|
||||
):
|
||||
event = getattr(chunk, "event", "")
|
||||
data = getattr(chunk, "data", None)
|
||||
|
||||
if event == "messages-tuple":
|
||||
if event in MESSAGE_STREAM_EVENTS:
|
||||
accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data)
|
||||
if accumulated_text:
|
||||
latest_text = accumulated_text
|
||||
@@ -1047,6 +1180,8 @@ class ChannelManager:
|
||||
text=latest_text,
|
||||
is_final=False,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata),
|
||||
)
|
||||
)
|
||||
@@ -1093,6 +1228,8 @@ class ChannelManager:
|
||||
attachments=attachments,
|
||||
is_final=True,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
)
|
||||
@@ -1124,18 +1261,10 @@ class ChannelManager:
|
||||
if reply is None and command == "new":
|
||||
# Create a new thread through Gateway
|
||||
client = self._get_client()
|
||||
thread = await client.threads.create()
|
||||
new_thread_id = thread["thread_id"]
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
new_thread_id,
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
await self._create_thread(client, msg)
|
||||
reply = "New conversation started."
|
||||
elif reply is None and command == "status":
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||
elif reply is None and command == "models":
|
||||
reply = await self._fetch_gateway("/api/models", "models")
|
||||
@@ -1174,9 +1303,11 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
text=reply,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
@@ -1212,9 +1343,11 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
text=error_text,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
|
||||
@@ -44,6 +44,12 @@ class InboundMessage:
|
||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||
reuse the same DeerFlow thread. When ``None``, each message
|
||||
creates a new thread (one-shot Q&A).
|
||||
connection_id: Optional DeerFlow channel connection id. When present,
|
||||
conversation mapping is scoped by the connection instead of the
|
||||
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
Platform user ids stay in ``user_id``.
|
||||
workspace_id: Optional external workspace/guild/team id.
|
||||
files: Optional list of file attachments (platform-specific dicts).
|
||||
metadata: Arbitrary extra data from the channel.
|
||||
created_at: Unix timestamp when the message was created.
|
||||
@@ -56,6 +62,9 @@ class InboundMessage:
|
||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||
thread_ts: str | None = None
|
||||
topic_id: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
workspace_id: str | None = None
|
||||
files: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
@@ -95,6 +104,9 @@ class OutboundMessage:
|
||||
is_final: Whether this is the final message in the response stream.
|
||||
thread_ts: Optional platform thread identifier for threaded replies.
|
||||
metadata: Arbitrary extra data.
|
||||
connection_id: Optional DeerFlow channel connection id used for
|
||||
connection-specific outbound credentials.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
created_at: Unix timestamp.
|
||||
"""
|
||||
|
||||
@@ -106,6 +118,8 @@ class OutboundMessage:
|
||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||
is_final: bool = True
|
||||
thread_ts: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Local persistence for runtime IM channel configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
|
||||
|
||||
|
||||
class ChannelRuntimeConfigStore:
|
||||
"""JSON-backed store for channel credentials entered from the UI.
|
||||
|
||||
This intentionally mirrors ``ChannelStore``: local/private deployments get
|
||||
durable runtime configuration without needing a public callback URL or a
|
||||
config.yaml edit.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str | Path | None = None) -> None:
|
||||
if path is None:
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
|
||||
self._path = Path(path)
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._data: dict[str, dict[str, Any]] = self._load()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _load(self) -> dict[str, dict[str, Any]]:
|
||||
if self._path.exists():
|
||||
try:
|
||||
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
|
||||
return {}
|
||||
if isinstance(raw, dict):
|
||||
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
|
||||
return {}
|
||||
|
||||
def _save(self) -> None:
|
||||
fd = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
dir=self._path.parent,
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
)
|
||||
try:
|
||||
json.dump(self._data, fd, indent=2, ensure_ascii=False)
|
||||
fd.close()
|
||||
Path(fd.name).replace(self._path)
|
||||
try:
|
||||
self._path.chmod(0o600)
|
||||
except OSError:
|
||||
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
|
||||
except BaseException:
|
||||
fd.close()
|
||||
Path(fd.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
def load_all(self) -> dict[str, dict[str, Any]]:
|
||||
with self._lock:
|
||||
return {name: dict(config) for name, config in self._data.items()}
|
||||
|
||||
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
|
||||
with self._lock:
|
||||
config = self._data.get(provider)
|
||||
return dict(config) if isinstance(config, dict) else None
|
||||
|
||||
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
|
||||
with self._lock:
|
||||
self._data[provider] = dict(config)
|
||||
self._save()
|
||||
|
||||
def set_provider_disconnected(self, provider: str) -> None:
|
||||
with self._lock:
|
||||
self._data[provider] = {
|
||||
"enabled": False,
|
||||
RUNTIME_CHANNEL_DISABLED_FLAG: True,
|
||||
}
|
||||
self._save()
|
||||
|
||||
def remove_provider_config(self, provider: str) -> bool:
|
||||
with self._lock:
|
||||
if provider not in self._data:
|
||||
return False
|
||||
del self._data[provider]
|
||||
self._save()
|
||||
return True
|
||||
|
||||
|
||||
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||
provider_config = getattr(channel_connections_config, provider, None)
|
||||
return bool(getattr(provider_config, "enabled", False))
|
||||
|
||||
|
||||
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
|
||||
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
|
||||
|
||||
|
||||
def merge_runtime_channel_configs(
|
||||
channels_config: dict[str, Any],
|
||||
channel_connections_config: Any,
|
||||
*,
|
||||
store: ChannelRuntimeConfigStore | None = None,
|
||||
) -> None:
|
||||
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
|
||||
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||
return
|
||||
|
||||
runtime_store = store or ChannelRuntimeConfigStore()
|
||||
for provider, runtime_config in runtime_store.load_all().items():
|
||||
if not _provider_enabled(channel_connections_config, provider):
|
||||
continue
|
||||
if _runtime_channel_disconnected(runtime_config):
|
||||
channels_config.pop(provider, None)
|
||||
continue
|
||||
existing = channels_config.get(provider)
|
||||
merged = dict(runtime_config)
|
||||
if isinstance(existing, dict):
|
||||
merged.update(existing)
|
||||
channels_config[provider] = merged
|
||||
|
||||
|
||||
def apply_runtime_connection_config(
|
||||
channel_connections_config: Any,
|
||||
*,
|
||||
store: ChannelRuntimeConfigStore | None = None,
|
||||
) -> Any:
|
||||
"""Apply persisted connection metadata that lives outside ``channels``.
|
||||
|
||||
Telegram uses a bot username for deep links; UI-entered values are stored
|
||||
with the runtime channel config so local restarts keep the provider
|
||||
configured.
|
||||
"""
|
||||
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||
return channel_connections_config
|
||||
|
||||
runtime_store = store or ChannelRuntimeConfigStore()
|
||||
telegram_runtime_config = runtime_store.get_provider_config("telegram")
|
||||
bot_username = ""
|
||||
if isinstance(telegram_runtime_config, dict):
|
||||
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
|
||||
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
|
||||
return channel_connections_config
|
||||
|
||||
config = channel_connections_config.model_copy(deep=True)
|
||||
config.telegram.bot_username = bot_username
|
||||
return config
|
||||
+170
-25
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from app.channels.base import Channel
|
||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||
from app.channels.message_bus import MessageBus
|
||||
from app.channels.runtime_config_store import merge_runtime_channel_configs
|
||||
from app.channels.store import ChannelStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||
|
||||
|
||||
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
|
||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
|
||||
|
||||
|
||||
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
||||
value = config.pop(config_key, None)
|
||||
if isinstance(value, str) and value.strip():
|
||||
@@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
||||
return default
|
||||
|
||||
|
||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
merge_runtime_channel_configs(channels_config, connection_config)
|
||||
|
||||
|
||||
def _make_connection_repo(app_config: AppConfig):
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return None
|
||||
|
||||
try:
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel connection repository")
|
||||
return None
|
||||
|
||||
session_factory = get_session_factory()
|
||||
if session_factory is None:
|
||||
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||
return None
|
||||
return ChannelConnectionRepository(session_factory)
|
||||
|
||||
|
||||
class ChannelService:
|
||||
"""Manages the lifecycle of all configured IM channels.
|
||||
|
||||
@@ -59,9 +90,10 @@ class ChannelService:
|
||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||
"""
|
||||
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||
self.bus = MessageBus()
|
||||
self.store = ChannelStore()
|
||||
self._connection_repo = connection_repo
|
||||
config = dict(channels_config or {})
|
||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||
@@ -74,10 +106,12 @@ class ChannelService:
|
||||
gateway_url=gateway_url,
|
||||
default_session=default_session if isinstance(default_session, dict) else None,
|
||||
channel_sessions=channel_sessions,
|
||||
connection_repo=connection_repo,
|
||||
)
|
||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||
self._config = config
|
||||
self._running = False
|
||||
self._readiness_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
@classmethod
|
||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
||||
@@ -90,8 +124,9 @@ class ChannelService:
|
||||
# extra fields are allowed by AppConfig (extra="allow")
|
||||
extra = app_config.model_extra or {}
|
||||
if "channels" in extra:
|
||||
channels_config = extra["channels"]
|
||||
return cls(channels_config=channels_config)
|
||||
channels_config = dict(extra["channels"] or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the manager and all enabled channels."""
|
||||
@@ -99,63 +134,169 @@ class ChannelService:
|
||||
return
|
||||
|
||||
await self.manager.start()
|
||||
self._running = True
|
||||
|
||||
ready_status = await self.ensure_ready_channels(attempts=2)
|
||||
ready_count = sum(1 for ready in ready_status.values() if ready)
|
||||
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
|
||||
|
||||
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
|
||||
"""Start or restart enabled configured channels that are not ready."""
|
||||
ready_status: dict[str, bool] = {}
|
||||
for name, channel_config in self._config.items():
|
||||
if not isinstance(channel_config, dict):
|
||||
continue
|
||||
if not channel_config.get("enabled", False):
|
||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
||||
if has_creds:
|
||||
if _channel_has_credentials(name, channel_config):
|
||||
logger.warning(
|
||||
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
|
||||
name,
|
||||
name,
|
||||
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
||||
)
|
||||
else:
|
||||
logger.info("Channel %s is disabled, skipping", name)
|
||||
logger.info("A configured channel is disabled, skipping")
|
||||
continue
|
||||
|
||||
await self._start_channel(name, channel_config)
|
||||
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
|
||||
return ready_status
|
||||
|
||||
self._running = True
|
||||
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
|
||||
async def ensure_channel_ready(
|
||||
self,
|
||||
name: str,
|
||||
config: dict[str, Any] | None = None,
|
||||
*,
|
||||
attempts: int = 1,
|
||||
) -> bool:
|
||||
"""Ensure a single enabled channel is running using its current config."""
|
||||
if not self._running:
|
||||
logger.warning("ChannelService is not running; cannot ensure channel readiness")
|
||||
return False
|
||||
|
||||
if config is not None:
|
||||
self._config[name] = dict(config)
|
||||
|
||||
# Serialize per channel: readiness is polled from request handlers, so
|
||||
# concurrent calls must not stop/start the same channel worker twice.
|
||||
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
|
||||
async with lock:
|
||||
channel_config = self._config.get(name)
|
||||
if not channel_config or not isinstance(channel_config, dict):
|
||||
logger.warning("No config for requested channel")
|
||||
return False
|
||||
if not channel_config.get("enabled", False):
|
||||
return False
|
||||
|
||||
channel = self._channels.get(name)
|
||||
if channel is not None and channel.is_running:
|
||||
return True
|
||||
|
||||
if channel is not None:
|
||||
try:
|
||||
await channel.stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping non-running channel before readiness retry")
|
||||
self._channels.pop(name, None)
|
||||
|
||||
max_attempts = max(1, attempts)
|
||||
for attempt in range(max_attempts):
|
||||
if attempt > 0:
|
||||
logger.info("Retrying channel startup after readiness check")
|
||||
if await self._start_channel(name, channel_config):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all channels and the manager."""
|
||||
for name, channel in list(self._channels.items()):
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Channel %s stopped", name)
|
||||
logger.info("Channel stopped")
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel %s", name)
|
||||
logger.exception("Error stopping channel")
|
||||
self._channels.clear()
|
||||
|
||||
await self.manager.stop()
|
||||
self._running = False
|
||||
logger.info("ChannelService stopped")
|
||||
|
||||
async def restart_channel(self, name: str) -> bool:
|
||||
def _load_channel_config(self, name: str) -> dict[str, Any] | None:
|
||||
"""Load the latest config for a specific channel from disk.
|
||||
|
||||
Uses ``get_app_config()`` which detects file changes via mtime,
|
||||
so edits to ``config.yaml`` are picked up without a process restart.
|
||||
The UI runtime-config overlay applied at startup is re-applied here
|
||||
so a file-driven reload neither drops credentials entered from the
|
||||
browser nor resurrects a channel disconnected from it.
|
||||
Falls back to the cached ``self._config`` when config loading fails.
|
||||
"""
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
app_config = get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = dict(extra.get("channels") or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
channel_config = channels_config.get(name)
|
||||
if isinstance(channel_config, dict):
|
||||
# Update the cached config so get_status() stays consistent.
|
||||
self._config[name] = channel_config
|
||||
return channel_config
|
||||
except Exception:
|
||||
logger.exception("Failed to reload config for channel %s, using cached version", name)
|
||||
return self._config.get(name)
|
||||
|
||||
async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
|
||||
"""Restart a specific channel. Returns True if successful."""
|
||||
if name in self._channels:
|
||||
try:
|
||||
await self._channels[name].stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel %s for restart", name)
|
||||
logger.exception("Error stopping channel for restart")
|
||||
del self._channels[name]
|
||||
|
||||
config = self._config.get(name)
|
||||
if reload_config:
|
||||
# Reading config.yaml and the runtime store is disk IO; keep it
|
||||
# off the event loop.
|
||||
config = await asyncio.to_thread(self._load_channel_config, name)
|
||||
else:
|
||||
config = self._config.get(name)
|
||||
if not config or not isinstance(config, dict):
|
||||
logger.warning("No config for channel %s", name)
|
||||
logger.warning("No config for requested channel")
|
||||
return False
|
||||
|
||||
if not config.get("enabled", False):
|
||||
logger.info("Channel %s is disabled, skipping restart", name)
|
||||
return True
|
||||
|
||||
return await self._start_channel(name, config)
|
||||
|
||||
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||
"""Apply runtime config for a channel and restart it if the service is running."""
|
||||
self._config[name] = dict(config)
|
||||
if not self._running:
|
||||
return True
|
||||
# The caller just supplied the authoritative config (e.g. credentials
|
||||
# entered in the browser that are never written to config.yaml) — a
|
||||
# file reload here would clobber it with the stale on-disk entry.
|
||||
return await self.restart_channel(name, reload_config=False)
|
||||
|
||||
async def remove_channel(self, name: str) -> bool:
|
||||
"""Remove runtime config for a channel and stop it if currently running."""
|
||||
self._config.pop(name, None)
|
||||
channel = self._channels.pop(name, None)
|
||||
if channel is None:
|
||||
return True
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Channel stopped and removed")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel for removal")
|
||||
return False
|
||||
|
||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||
"""Instantiate and start a single channel."""
|
||||
import_path = _CHANNEL_REGISTRY.get(name)
|
||||
if not import_path:
|
||||
logger.warning("Unknown channel type: %s", name)
|
||||
logger.warning("Unknown channel type")
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -163,24 +304,26 @@ class ChannelService:
|
||||
|
||||
channel_cls = resolve_class(import_path, base_class=None)
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel class for %s", name)
|
||||
logger.exception("Failed to import channel class")
|
||||
return False
|
||||
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
if self._connection_repo is not None:
|
||||
config["connection_repo"] = self._connection_repo
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
if not channel.is_running:
|
||||
self._channels.pop(name, None)
|
||||
logger.error("Channel %s did not enter a running state after start()", name)
|
||||
logger.error("Channel did not enter a running state after start()")
|
||||
return False
|
||||
logger.info("Channel %s started", name)
|
||||
logger.info("Channel started")
|
||||
return True
|
||||
except Exception:
|
||||
self._channels.pop(name, None)
|
||||
logger.exception("Failed to start channel %s", name)
|
||||
logger.exception("Failed to start channel")
|
||||
return False
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
@@ -219,7 +362,9 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
|
||||
global _channel_service
|
||||
if _channel_service is not None:
|
||||
return _channel_service
|
||||
_channel_service = ChannelService.from_app_config(app_config)
|
||||
# from_app_config reads the JSON channel store and runtime config files;
|
||||
# keep that disk IO off the event loop.
|
||||
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
|
||||
await _channel_service.start()
|
||||
return _channel_service
|
||||
|
||||
|
||||
+148
-26
@@ -9,7 +9,8 @@ from typing import Any
|
||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,6 +65,9 @@ class SlackChannel(Channel):
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._web_client_factory = config.get("web_client_factory")
|
||||
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
|
||||
configured_bot_user_id = config.get("bot_user_id")
|
||||
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
||||
|
||||
@@ -80,26 +84,28 @@ class SlackChannel(Channel):
|
||||
return
|
||||
|
||||
self._SocketModeResponse = SocketModeResponse
|
||||
if self._web_client_factory is None:
|
||||
self._web_client_factory = WebClient
|
||||
|
||||
bot_token = self.config.get("bot_token", "")
|
||||
app_token = self.config.get("app_token", "")
|
||||
|
||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||
if not bot_token:
|
||||
logger.error("Slack HTTP Events mode requires bot_token")
|
||||
return
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("Slack channel started in HTTP Events mode")
|
||||
return
|
||||
|
||||
if not bot_token or not app_token:
|
||||
logger.error("Slack channel requires bot_token and app_token")
|
||||
return
|
||||
|
||||
self._web_client = WebClient(token=bot_token)
|
||||
if self._bot_user_id is None:
|
||||
try:
|
||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||
if user_id is None:
|
||||
auth_get = getattr(auth_info, "get", None)
|
||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||
if isinstance(user_id, str) and user_id:
|
||||
self._bot_user_id = user_id
|
||||
except Exception:
|
||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=app_token,
|
||||
web_client=self._web_client,
|
||||
@@ -124,7 +130,8 @@ class SlackChannel(Channel):
|
||||
logger.info("Slack channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._web_client:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
return
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
@@ -137,11 +144,12 @@ class SlackChannel(Channel):
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||
# Add a completion reaction to the thread root
|
||||
if msg.thread_ts:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction,
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"white_check_mark",
|
||||
@@ -165,7 +173,8 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction,
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"x",
|
||||
@@ -177,7 +186,8 @@ class SlackChannel(Channel):
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._web_client:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -190,7 +200,7 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
kwargs["thread_ts"] = msg.thread_ts
|
||||
|
||||
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
||||
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||
return True
|
||||
except Exception:
|
||||
@@ -199,12 +209,45 @@ class SlackChannel(Channel):
|
||||
|
||||
# -- internal ----------------------------------------------------------
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
||||
self._web_client = self._web_client_factory(token=bot_token)
|
||||
if self._bot_user_id is not None:
|
||||
return
|
||||
try:
|
||||
self._web_client.reactions_add(
|
||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||
if user_id is None:
|
||||
auth_get = getattr(auth_info, "get", None)
|
||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||
if isinstance(user_id, str) and user_id:
|
||||
self._bot_user_id = user_id
|
||||
except Exception:
|
||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||
|
||||
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||
access_token = credentials.get("access_token") if credentials else None
|
||||
if not access_token:
|
||||
return self._web_client
|
||||
# WebClient keeps its own HTTP session and rate-limit state, so
|
||||
# reuse one per connection until its token changes.
|
||||
cached = self._connection_web_clients.get(msg.connection_id)
|
||||
if cached is not None and cached[0] == access_token:
|
||||
return cached[1]
|
||||
if self._web_client_factory is None:
|
||||
from slack_sdk import WebClient
|
||||
|
||||
self._web_client_factory = WebClient
|
||||
web_client = self._web_client_factory(token=access_token)
|
||||
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
|
||||
return web_client
|
||||
return self._web_client
|
||||
|
||||
@staticmethod
|
||||
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
try:
|
||||
web_client.reactions_add(
|
||||
channel=channel_id,
|
||||
timestamp=timestamp,
|
||||
name=emoji,
|
||||
@@ -213,6 +256,12 @@ class SlackChannel(Channel):
|
||||
if "already_reacted" not in str(exc):
|
||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
return
|
||||
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
||||
|
||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||
if not self._web_client:
|
||||
@@ -249,12 +298,15 @@ class SlackChannel(Channel):
|
||||
|
||||
# Handle message events (DM or @mention)
|
||||
if etype in ("message", "app_mention"):
|
||||
self._handle_message_event(event)
|
||||
self._handle_message_event(
|
||||
event,
|
||||
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing Slack event")
|
||||
|
||||
def _handle_message_event(self, event: dict) -> None:
|
||||
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
|
||||
# Ignore bot messages
|
||||
if event.get("bot_id") or event.get("subtype"):
|
||||
return
|
||||
@@ -272,6 +324,19 @@ class SlackChannel(Channel):
|
||||
if not text:
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code:
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
event=event,
|
||||
team_id=str(team_id or event.get("team") or ""),
|
||||
code=connect_code,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
return
|
||||
|
||||
channel_id = event.get("channel", "")
|
||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||
|
||||
@@ -297,4 +362,61 @@ class SlackChannel(Channel):
|
||||
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
||||
# Send "running" reply first (fire-and-forget from SDK thread)
|
||||
self._send_running_reply(channel_id, thread_ts)
|
||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||
if self._connection_repo is None:
|
||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||
else:
|
||||
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
|
||||
|
||||
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
|
||||
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="slack",
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
channel_id = str(event.get("channel") or "")
|
||||
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
||||
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
|
||||
if state is None:
|
||||
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
|
||||
return True
|
||||
|
||||
user_id = str(event.get("user") or "")
|
||||
if not user_id or not team_id:
|
||||
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="slack",
|
||||
external_account_id=user_id,
|
||||
workspace_id=team_id,
|
||||
metadata={
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
|
||||
return True
|
||||
|
||||
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
|
||||
if not self._web_client or not channel_id:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
|
||||
if thread_ts:
|
||||
kwargs["thread_ts"] = thread_ts
|
||||
try:
|
||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||
except Exception:
|
||||
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
|
||||
|
||||
@@ -5,13 +5,27 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TELEGRAM_MAX_MESSAGE_LENGTH = 4096
|
||||
STREAM_EDIT_MIN_INTERVAL_SECONDS = 1.0
|
||||
# Groups (negative chat_id) are capped at 20 messages/minute by Telegram,
|
||||
# so stream edits there must pace well below the private-chat 1 msg/s guideline.
|
||||
STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS = 3.0
|
||||
# Bound on tracked in-flight streamed messages; entries normally clear on the
|
||||
# final update, this only guards against leaks when a final never arrives.
|
||||
MAX_TRACKED_STREAM_MESSAGES = 256
|
||||
|
||||
# Indirection so tests can patch the clock without touching the global time module.
|
||||
_monotonic = time.monotonic
|
||||
|
||||
|
||||
class TelegramChannel(Channel):
|
||||
"""Telegram bot channel using long-polling.
|
||||
@@ -35,6 +49,14 @@ class TelegramChannel(Channel):
|
||||
pass
|
||||
# chat_id -> last sent message_id for threaded replies
|
||||
self._last_bot_message: dict[str, int] = {}
|
||||
# stream_key ("chat_id:thread_ts") -> state of the in-flight streamed
|
||||
# bot message being edited in place: {"message_id", "last_edit_at", "last_text"}
|
||||
self._stream_messages: dict[str, dict[str, Any]] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -102,10 +124,117 @@ class TelegramChannel(Channel):
|
||||
logger.error("Invalid Telegram chat_id: %s", msg.chat_id)
|
||||
return
|
||||
|
||||
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": msg.text}
|
||||
key = self._stream_key(msg.chat_id, msg.thread_ts)
|
||||
|
||||
if not msg.is_final:
|
||||
await self._send_stream_update(chat_id, key, msg.text, reply_to=self._parse_message_id(msg.thread_ts))
|
||||
return
|
||||
|
||||
state = self._stream_messages.pop(key, None)
|
||||
if state is not None:
|
||||
await self._finalize_stream_message(chat_id, msg.chat_id, state, msg.text)
|
||||
return
|
||||
|
||||
await self._send_new_message(chat_id, msg.chat_id, msg.text, _max_retries=_max_retries)
|
||||
|
||||
async def _send_stream_update(self, chat_id: int, key: str, text: str, reply_to: int | None = None) -> None:
|
||||
"""Edit the in-flight streamed message with accumulated text.
|
||||
|
||||
Updates are best-effort: throttled, rate-limit drops are silent. The
|
||||
manager always publishes a final message afterwards, which guarantees
|
||||
delivery of the complete text.
|
||||
"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
display = text
|
||||
if len(display) > TELEGRAM_MAX_MESSAGE_LENGTH:
|
||||
display = display[: TELEGRAM_MAX_MESSAGE_LENGTH - 1] + "…"
|
||||
|
||||
bot = self._application.bot
|
||||
state = self._stream_messages.get(key)
|
||||
|
||||
send_kwargs: dict[str, Any] = {"chat_id": chat_id, "text": display}
|
||||
if reply_to:
|
||||
send_kwargs["reply_to_message_id"] = reply_to
|
||||
|
||||
if state is None:
|
||||
try:
|
||||
sent = await bot.send_message(**send_kwargs)
|
||||
except Exception:
|
||||
logger.exception("[Telegram] failed to start stream message in chat=%s", chat_id)
|
||||
return
|
||||
self._register_stream_message(key, message_id=sent.message_id, last_text=display, last_edit_at=_monotonic())
|
||||
return
|
||||
|
||||
now = _monotonic()
|
||||
min_interval = STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS if chat_id < 0 else STREAM_EDIT_MIN_INTERVAL_SECONDS
|
||||
if now - state["last_edit_at"] < min_interval:
|
||||
return
|
||||
if display == state["last_text"]:
|
||||
return
|
||||
|
||||
try:
|
||||
await bot.edit_message_text(chat_id=chat_id, message_id=state["message_id"], text=display)
|
||||
except Exception as exc:
|
||||
if self._is_not_modified(exc):
|
||||
state["last_text"] = display
|
||||
return
|
||||
if self._is_retry_after(exc):
|
||||
logger.debug("[Telegram] stream edit rate-limited in chat=%s, dropping update", chat_id)
|
||||
return
|
||||
logger.warning("[Telegram] stream edit failed in chat=%s, sending new message: %s", chat_id, exc)
|
||||
try:
|
||||
sent = await bot.send_message(**send_kwargs)
|
||||
except Exception:
|
||||
logger.exception("[Telegram] failed to send fallback stream message in chat=%s", chat_id)
|
||||
return
|
||||
state["message_id"] = sent.message_id
|
||||
|
||||
state["last_edit_at"] = _monotonic()
|
||||
state["last_text"] = display
|
||||
|
||||
async def _finalize_stream_message(self, chat_id: int, chat_key: str, state: dict[str, Any], text: str) -> None:
|
||||
"""Apply the final text: edit the streamed message, splitting overflow into follow-ups."""
|
||||
bot = self._application.bot
|
||||
chunks = self._split_message(text or "")
|
||||
|
||||
edited = True
|
||||
if chunks[0] != state["last_text"]:
|
||||
edited = await self._edit_final_chunk(bot, chat_id, state["message_id"], chunks[0])
|
||||
|
||||
if edited:
|
||||
self._last_bot_message[chat_key] = state["message_id"]
|
||||
else:
|
||||
# Edit could not be applied (e.g. message deleted) — deliver the
|
||||
# first chunk as a fresh message with the standard retry policy.
|
||||
await self._send_new_message(chat_id, chat_key, chunks[0])
|
||||
|
||||
for chunk in chunks[1:]:
|
||||
await self._send_new_message(chat_id, chat_key, chunk)
|
||||
|
||||
async def _edit_final_chunk(self, bot, chat_id: int, message_id: int, text: str) -> bool:
|
||||
"""Edit with one rate-limit retry. Returns False if the edit could not be applied."""
|
||||
for attempt in range(2):
|
||||
try:
|
||||
await bot.edit_message_text(chat_id=chat_id, message_id=message_id, text=text)
|
||||
return True
|
||||
except Exception as exc:
|
||||
if self._is_not_modified(exc):
|
||||
return True
|
||||
if self._is_retry_after(exc) and attempt == 0:
|
||||
await asyncio.sleep(self._retry_after_seconds(exc))
|
||||
continue
|
||||
logger.warning("[Telegram] final edit failed in chat=%s: %s", chat_id, exc)
|
||||
return False
|
||||
return False
|
||||
|
||||
async def _send_new_message(self, chat_id: int, chat_key: str, text: str, *, _max_retries: int = 3) -> int | None:
|
||||
"""Send a fresh message with retry/backoff. Returns the sent message_id."""
|
||||
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": text}
|
||||
|
||||
# Reply to the last bot message in this chat for threading
|
||||
reply_to = self._last_bot_message.get(msg.chat_id)
|
||||
reply_to = self._last_bot_message.get(chat_key)
|
||||
if reply_to:
|
||||
kwargs["reply_to_message_id"] = reply_to
|
||||
|
||||
@@ -114,8 +243,8 @@ class TelegramChannel(Channel):
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
sent = await bot.send_message(**kwargs)
|
||||
self._last_bot_message[msg.chat_id] = sent.message_id
|
||||
return
|
||||
self._last_bot_message[chat_key] = sent.message_id
|
||||
return sent.message_id
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
if attempt < _max_retries - 1:
|
||||
@@ -178,17 +307,63 @@ class TelegramChannel(Channel):
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _stream_key(chat_id: str, thread_ts: str | None) -> str:
|
||||
return f"{chat_id}:{thread_ts or ''}"
|
||||
|
||||
@staticmethod
|
||||
def _parse_message_id(value: str | None) -> int | None:
|
||||
try:
|
||||
return int(value) if value else None
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _register_stream_message(self, key: str, *, message_id: int, last_text: str, last_edit_at: float) -> None:
|
||||
self._stream_messages.pop(key, None)
|
||||
while len(self._stream_messages) >= MAX_TRACKED_STREAM_MESSAGES:
|
||||
self._stream_messages.pop(next(iter(self._stream_messages)))
|
||||
self._stream_messages[key] = {
|
||||
"message_id": message_id,
|
||||
"last_edit_at": last_edit_at,
|
||||
"last_text": last_text,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _is_retry_after(exc: Exception) -> bool:
|
||||
return getattr(exc, "retry_after", None) is not None
|
||||
|
||||
@staticmethod
|
||||
def _retry_after_seconds(exc: Exception) -> float:
|
||||
value = getattr(exc, "retry_after", 0)
|
||||
if hasattr(value, "total_seconds"):
|
||||
return float(value.total_seconds())
|
||||
return float(value)
|
||||
|
||||
@staticmethod
|
||||
def _is_not_modified(exc: Exception) -> bool:
|
||||
return "message is not modified" in str(exc).lower()
|
||||
|
||||
@staticmethod
|
||||
def _split_message(text: str) -> list[str]:
|
||||
return [text[i : i + TELEGRAM_MAX_MESSAGE_LENGTH] for i in range(0, len(text), TELEGRAM_MAX_MESSAGE_LENGTH)] or [text]
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||
"""Send a 'Working on it...' reply to the user's message."""
|
||||
"""Send a 'Working on it...' reply and register it as the stream target."""
|
||||
if not self._application:
|
||||
return
|
||||
try:
|
||||
bot = self._application.bot
|
||||
await bot.send_message(
|
||||
sent = await bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text="Working on it...",
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
)
|
||||
self._register_stream_message(
|
||||
self._stream_key(chat_id, str(reply_to_message_id)),
|
||||
message_id=sent.message_id,
|
||||
last_text="Working on it...",
|
||||
last_edit_at=0.0,
|
||||
)
|
||||
logger.info("[Telegram] 'Working on it...' reply sent in chat=%s", chat_id)
|
||||
except Exception:
|
||||
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
||||
@@ -233,6 +408,54 @@ class TelegramChannel(Channel):
|
||||
return True
|
||||
return user_id in self._allowed_users
|
||||
|
||||
@staticmethod
|
||||
def _telegram_display_name(user) -> str:
|
||||
full_name = getattr(user, "full_name", None)
|
||||
if isinstance(full_name, str) and full_name:
|
||||
return full_name
|
||||
username = getattr(user, "username", None)
|
||||
if isinstance(username, str) and username:
|
||||
return username
|
||||
return str(getattr(user, "id", ""))
|
||||
|
||||
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
||||
if self._connection_repo is None or not state_token:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
||||
if state is None:
|
||||
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
||||
return True
|
||||
|
||||
owner_user_id = state["owner_user_id"]
|
||||
user_id = str(update.effective_user.id)
|
||||
chat_id = str(update.effective_chat.id)
|
||||
connection = await self._connection_repo.upsert_connection(
|
||||
owner_user_id=owner_user_id,
|
||||
provider="telegram",
|
||||
external_account_id=user_id,
|
||||
external_account_name=self._telegram_display_name(update.effective_user),
|
||||
workspace_id=chat_id,
|
||||
workspace_name=None,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"chat_type": update.effective_chat.type,
|
||||
"telegram_username": getattr(update.effective_user, "username", None),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
||||
await update.message.reply_text("Telegram connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="telegram",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
def _get_bot_username(self, context) -> str | None:
|
||||
bot = getattr(context, "bot", None)
|
||||
username = getattr(bot, "username", None)
|
||||
@@ -264,6 +487,11 @@ class TelegramChannel(Channel):
|
||||
"""Handle /start command."""
|
||||
if not self._check_user(update.effective_user.id):
|
||||
return
|
||||
args = getattr(context, "args", []) if context is not None else []
|
||||
if args:
|
||||
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||
if handled:
|
||||
return
|
||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||
|
||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||
@@ -299,6 +527,7 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
@@ -341,6 +570,7 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
|
||||
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
|
||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._load_state()
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
|
||||
if thread_ts:
|
||||
self._context_tokens_by_thread[thread_ts] = context_token
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound = self._make_inbound(
|
||||
chat_id=chat_id,
|
||||
user_id=chat_id,
|
||||
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = None
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wechat",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not chat_id:
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wechat",
|
||||
external_account_id=chat_id,
|
||||
workspace_id=chat_id,
|
||||
metadata={
|
||||
"context_token": context_token,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
|
||||
if not context_token:
|
||||
return
|
||||
await self._send_text_message(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
text=text,
|
||||
client_id_prefix="deerflow-connect",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
async def _ensure_authenticated(self) -> bool:
|
||||
async with self._auth_lock:
|
||||
if self._bot_token:
|
||||
|
||||
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
InboundMessage,
|
||||
InboundMessageType,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
|
||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||
self._ws_stream_ids: dict[str, str] = {}
|
||||
self._working_message = "Working on it..."
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -79,12 +82,33 @@ class WeComChannel(Channel):
|
||||
self._ws_client.on("message.mixed", self._on_ws_mixed)
|
||||
self._ws_client.on("message.image", self._on_ws_image)
|
||||
self._ws_client.on("message.file", self._on_ws_file)
|
||||
self._ws_client.on("error", self._on_ws_error)
|
||||
self._ws_client.on("disconnected", self._on_ws_disconnected)
|
||||
self._ws_task = asyncio.create_task(self._ws_client.connect())
|
||||
self._ws_task.add_done_callback(self._on_ws_task_done)
|
||||
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("WeCom channel started")
|
||||
|
||||
def _on_ws_task_done(self, task: asyncio.Task) -> None:
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc is None:
|
||||
return
|
||||
logger.error(
|
||||
"WeCom WebSocket connection task failed: %s. Check that the network/proxy allows wss://openws.work.weixin.qq.com and that bot_id/bot_secret are valid.",
|
||||
exc,
|
||||
)
|
||||
|
||||
def _on_ws_error(self, error: Any) -> None:
|
||||
logger.error("WeCom WebSocket error: %s", error)
|
||||
|
||||
def _on_ws_disconnected(self, *args: Any) -> None:
|
||||
detail = f" ({args[0]})" if args else ""
|
||||
logger.warning("WeCom WebSocket disconnected%s; SDK will attempt to reconnect", detail)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
||||
@@ -271,6 +295,16 @@ class WeComChannel(Channel):
|
||||
|
||||
user_id = (body.get("from") or {}).get("userid")
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id=str(user_id or ""),
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=user_id, # keep user's conversation in memory
|
||||
@@ -292,8 +326,52 @@ class WeComChannel(Channel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wecom",
|
||||
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id:
|
||||
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
body = frame.get("body", {}) or {}
|
||||
workspace_id = str(body.get("aibotid") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wecom",
|
||||
external_account_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
metadata={
|
||||
"aibotid": workspace_id,
|
||||
"chattype": body.get("chattype"),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
||||
|
||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
|
||||
+26
-14
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||
from app.gateway.auth_middleware import AuthMiddleware
|
||||
from app.gateway.config import get_gateway_config
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
||||
@@ -15,6 +16,7 @@ from app.gateway.routers import (
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channel_connections,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
@@ -172,6 +174,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
startup_config = get_app_config()
|
||||
apply_logging_level(startup_config.log_level)
|
||||
logger.info("Configuration loaded successfully")
|
||||
warn_if_auth_disabled_enabled()
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||
logger.exception(error_msg)
|
||||
@@ -182,21 +185,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
||||
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
||||
# that may be unreachable in restricted networks — see issue #3402).
|
||||
try:
|
||||
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||
# When memory.token_counting is "char", token counting never touches
|
||||
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
|
||||
# network-restricted deployments — see issue #3429).
|
||||
if startup_config.memory.token_counting == "char":
|
||||
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
|
||||
else:
|
||||
try:
|
||||
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||
|
||||
warmed = await asyncio.wait_for(
|
||||
asyncio.to_thread(warm_tiktoken_cache),
|
||||
timeout=5,
|
||||
)
|
||||
if warmed:
|
||||
logger.info("tiktoken encoding cache warmed successfully")
|
||||
else:
|
||||
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
|
||||
except TimeoutError:
|
||||
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
|
||||
except Exception:
|
||||
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||
warmed = await asyncio.wait_for(
|
||||
asyncio.to_thread(warm_tiktoken_cache),
|
||||
timeout=5,
|
||||
)
|
||||
if warmed:
|
||||
logger.info("tiktoken encoding cache warmed successfully")
|
||||
else:
|
||||
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
|
||||
except TimeoutError:
|
||||
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
|
||||
except Exception:
|
||||
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||
|
||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
@@ -376,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||
app.include_router(suggestions.router)
|
||||
|
||||
# User-facing IM channel connection API is mounted at /api/channels
|
||||
app.include_router(channel_connections.router)
|
||||
|
||||
# Channels API is mounted at /api/channels
|
||||
app.include_router(channels.router)
|
||||
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Shared helpers for local/E2E auth-disabled mode."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
|
||||
AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
|
||||
AUTH_DISABLED_USER_EMAIL = "default@test.local"
|
||||
|
||||
AUTH_SOURCE_SESSION = "session"
|
||||
AUTH_SOURCE_INTERNAL = "internal"
|
||||
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
|
||||
|
||||
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
|
||||
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_explicit_production_environment() -> bool:
|
||||
return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
|
||||
|
||||
|
||||
def is_auth_disabled_requested() -> bool:
|
||||
return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
|
||||
|
||||
|
||||
def is_auth_disabled() -> bool:
|
||||
return is_auth_disabled_requested() and not is_explicit_production_environment()
|
||||
|
||||
|
||||
def warn_if_auth_disabled_enabled() -> None:
|
||||
if not is_auth_disabled():
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
|
||||
AUTH_DISABLED_ENV_VAR,
|
||||
AUTH_DISABLED_USER_ID,
|
||||
)
|
||||
|
||||
|
||||
def get_auth_disabled_user():
|
||||
return SimpleNamespace(
|
||||
id=AUTH_DISABLED_USER_ID,
|
||||
email=AUTH_DISABLED_USER_EMAIL,
|
||||
password_hash=None,
|
||||
system_role="admin",
|
||||
needs_setup=False,
|
||||
token_version=0,
|
||||
)
|
||||
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||
from app.gateway.auth_disabled import (
|
||||
AUTH_SOURCE_AUTH_DISABLED,
|
||||
AUTH_SOURCE_INTERNAL,
|
||||
AUTH_SOURCE_SESSION,
|
||||
get_auth_disabled_user,
|
||||
is_auth_disabled,
|
||||
)
|
||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
@@ -80,8 +87,38 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
||||
internal_user = get_internal_user()
|
||||
|
||||
auth_source = AUTH_SOURCE_SESSION
|
||||
access_token = request.cookies.get("access_token")
|
||||
|
||||
# Non-public path: require session cookie
|
||||
if internal_user is None and not request.cookies.get("access_token"):
|
||||
if internal_user is not None:
|
||||
user = internal_user
|
||||
auth_source = AUTH_SOURCE_INTERNAL
|
||||
elif access_token:
|
||||
# Strict JWT validation: reject junk/expired tokens with 401
|
||||
# right here instead of silently passing through. This closes
|
||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||
# without this, non-isolation routes like /api/models would
|
||||
# accept any cookie-shaped string as authentication.
|
||||
#
|
||||
# We call the *strict* resolver so that fine-grained error
|
||||
# codes (token_expired, token_invalid, user_not_found, …)
|
||||
# propagate from AuthErrorCode, not get flattened into one
|
||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
||||
# bubble up, so we catch and render it as JSONResponse here.
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
try:
|
||||
user = await get_current_user_from_request(request)
|
||||
except HTTPException as exc:
|
||||
if not is_auth_disabled():
|
||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
user = get_auth_disabled_user()
|
||||
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||
elif is_auth_disabled():
|
||||
user = get_auth_disabled_user()
|
||||
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||
else:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
@@ -92,32 +129,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
},
|
||||
)
|
||||
|
||||
# Strict JWT validation: reject junk/expired tokens with 401
|
||||
# right here instead of silently passing through. This closes
|
||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||
# without this, non-isolation routes like /api/models would
|
||||
# accept any cookie-shaped string as authentication.
|
||||
#
|
||||
# We call the *strict* resolver so that fine-grained error
|
||||
# codes (token_expired, token_invalid, user_not_found, …)
|
||||
# propagate from AuthErrorCode, not get flattened into one
|
||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
||||
# bubble up, so we catch and render it as JSONResponse here.
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
if internal_user is not None:
|
||||
user = internal_user
|
||||
else:
|
||||
try:
|
||||
user = await get_current_user_from_request(request)
|
||||
except HTTPException as exc:
|
||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
|
||||
# Stamp both request.state.user (for the contextvar pattern)
|
||||
# and request.state.auth (so @require_permission's "auth is
|
||||
# None" branch short-circuits instead of running the entire
|
||||
# JWT-decode + DB-lookup pipeline a second time per request).
|
||||
request.state.user = user
|
||||
request.state.auth_source = auth_source
|
||||
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||
token = set_current_user(user)
|
||||
try:
|
||||
|
||||
@@ -276,6 +276,8 @@ def require_permission(
|
||||
# strict-deny rather than strict-allow — only an *existing*
|
||||
# row with a *different* user_id triggers 404.
|
||||
if owner_check:
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
@@ -288,6 +290,22 @@ def require_permission(
|
||||
str(auth.user.id),
|
||||
require_existing=require_existing,
|
||||
)
|
||||
if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||
# Trusted internal callers (channel workers) also act for
|
||||
# the connection owner carried in X-DeerFlow-Owner-User-Id.
|
||||
# Scope the check to that owner instead of bypassing it; a
|
||||
# leaked internal token must not grant cross-user thread
|
||||
# access. The header is honored only after ``auth`` proved
|
||||
# the caller holds the internal token (mirrors
|
||||
# get_trusted_internal_owner_user_id, which keys off the
|
||||
# middleware-stamped ``request.state.user``).
|
||||
header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip()
|
||||
if header_owner:
|
||||
allowed = await thread_store.check_access(
|
||||
thread_id,
|
||||
header_owner,
|
||||
require_existing=require_existing,
|
||||
)
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
||||
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.gateway.auth_disabled import is_auth_disabled
|
||||
|
||||
CSRF_COOKIE_NAME = "csrf_token"
|
||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
|
||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||
return False
|
||||
|
||||
if is_auth_disabled():
|
||||
return False
|
||||
|
||||
path = request.url.path.rstrip("/")
|
||||
# Exempt /api/v1/auth/me endpoint
|
||||
if path == "/api/v1/auth/me":
|
||||
|
||||
@@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request):
|
||||
|
||||
Raises HTTPException 401 if not authenticated.
|
||||
"""
|
||||
state = getattr(request, "state", None)
|
||||
state_user = getattr(state, "user", None)
|
||||
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
|
||||
|
||||
if state_user is not None and getattr(state, "auth_source", None) in {
|
||||
AUTH_SOURCE_SESSION,
|
||||
AUTH_SOURCE_AUTH_DISABLED,
|
||||
AUTH_SOURCE_INTERNAL,
|
||||
}:
|
||||
return state_user
|
||||
|
||||
from app.gateway.auth import decode_token
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||
|
||||
|
||||
@@ -5,10 +5,12 @@ from __future__ import annotations
|
||||
import os
|
||||
import secrets
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||
INTERNAL_SYSTEM_ROLE = "internal"
|
||||
|
||||
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
|
||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||
|
||||
|
||||
def create_internal_auth_headers() -> dict[str, str]:
|
||||
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
if owner_user_id:
|
||||
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
||||
return headers
|
||||
|
||||
|
||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
def get_internal_user():
|
||||
"""Return the synthetic user used for trusted internal channel calls."""
|
||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||
|
||||
|
||||
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
|
||||
"""Return the owner override for a trusted internal request, if present.
|
||||
|
||||
The header is ignored for normal browser/API callers. It is only honored
|
||||
after ``AuthMiddleware`` has validated the internal auth token and stamped
|
||||
the synthetic internal user onto ``request.state.user``.
|
||||
"""
|
||||
user = getattr(getattr(request, "state", None), "user", None)
|
||||
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||
return None
|
||||
|
||||
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
|
||||
if not owner_user_id:
|
||||
return None
|
||||
owner_user_id = owner_user_id.strip()
|
||||
return owner_user_id or None
|
||||
|
||||
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||
from app.gateway.deps import get_local_provider
|
||||
|
||||
auth = Auth()
|
||||
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
|
||||
if method.upper() not in _CSRF_METHODS:
|
||||
return
|
||||
|
||||
if is_auth_disabled():
|
||||
return
|
||||
|
||||
cookie_token = request.cookies.get("csrf_token")
|
||||
header_token = request.headers.get("x-csrf-token")
|
||||
|
||||
@@ -66,6 +70,9 @@ async def authenticate(request):
|
||||
# are rejected early, even if the cookie carries a valid JWT.
|
||||
_check_csrf(request)
|
||||
|
||||
if is_auth_disabled():
|
||||
return AUTH_DISABLED_USER_ID
|
||||
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
|
||||
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
|
||||
- Re-issues session cookie with new token_version
|
||||
"""
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=AuthErrorResponse(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||
|
||||
|
||||
@@ -0,0 +1,670 @@
|
||||
"""Browser-facing APIs for user-owned IM channel bindings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.channels.runtime_config_store import (
|
||||
ChannelRuntimeConfigStore,
|
||||
apply_runtime_connection_config,
|
||||
merge_runtime_channel_configs,
|
||||
)
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STATE_TTL_SECONDS = 600
|
||||
_MASKED_CREDENTIAL_VALUE = "********"
|
||||
|
||||
|
||||
class ChannelCredentialFieldResponse(BaseModel):
|
||||
name: str
|
||||
label: str
|
||||
type: str = "text"
|
||||
required: bool = True
|
||||
|
||||
|
||||
class ChannelProviderResponse(BaseModel):
|
||||
provider: str
|
||||
display_name: str
|
||||
enabled: bool
|
||||
configured: bool
|
||||
connectable: bool
|
||||
unavailable_reason: str | None = None
|
||||
auth_mode: str
|
||||
connection_status: str
|
||||
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
|
||||
credential_values: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelProvidersResponse(BaseModel):
|
||||
enabled: bool
|
||||
providers: list[ChannelProviderResponse]
|
||||
|
||||
|
||||
class ChannelConnectionResponse(BaseModel):
|
||||
id: str
|
||||
provider: str
|
||||
status: str
|
||||
external_account_id: str | None = None
|
||||
external_account_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelConnectionsResponse(BaseModel):
|
||||
connections: list[ChannelConnectionResponse]
|
||||
|
||||
|
||||
class ChannelConnectResponse(BaseModel):
|
||||
provider: str
|
||||
mode: str
|
||||
url: str | None = None
|
||||
code: str
|
||||
instruction: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
class ChannelRuntimeConfigRequest(BaseModel):
|
||||
values: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
_PROVIDER_META: dict[str, dict[str, str]] = {
|
||||
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
|
||||
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
|
||||
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
|
||||
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
|
||||
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
|
||||
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
|
||||
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
|
||||
}
|
||||
|
||||
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
|
||||
"telegram": (
|
||||
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||
{"name": "bot_username", "label": "Bot username", "type": "text"},
|
||||
),
|
||||
"slack": (
|
||||
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||
{"name": "app_token", "label": "App token", "type": "password"},
|
||||
),
|
||||
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||
"feishu": (
|
||||
{"name": "app_id", "label": "App ID", "type": "text"},
|
||||
{"name": "app_secret", "label": "App secret", "type": "password"},
|
||||
),
|
||||
"dingtalk": (
|
||||
{"name": "client_id", "label": "Client ID", "type": "text"},
|
||||
{"name": "client_secret", "label": "Client secret", "type": "password"},
|
||||
),
|
||||
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||
"wecom": (
|
||||
{"name": "bot_id", "label": "Bot ID", "type": "text"},
|
||||
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
|
||||
),
|
||||
}
|
||||
|
||||
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
|
||||
"telegram": ("bot_token",),
|
||||
"slack": ("bot_token", "app_token"),
|
||||
"discord": ("bot_token",),
|
||||
"feishu": ("app_id", "app_secret"),
|
||||
"dingtalk": ("client_id", "client_secret"),
|
||||
"wechat": ("bot_token",),
|
||||
"wecom": ("bot_id", "bot_secret"),
|
||||
}
|
||||
|
||||
|
||||
def _get_user_id(request: Request) -> str:
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return str(user.id)
|
||||
|
||||
|
||||
async def _require_admin_user(request: Request) -> None:
|
||||
"""Require an admin caller for instance-wide channel runtime mutations.
|
||||
|
||||
Runtime credentials and the channel workers they start/stop are shared by
|
||||
every user of the deployment, so only admins may change them (same model
|
||||
as the MCP config API). Auth-disabled local mode uses a synthetic admin
|
||||
user and is unaffected.
|
||||
"""
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is None:
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if getattr(user, "system_role", None) != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.")
|
||||
|
||||
|
||||
def _get_app_config():
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
return get_app_config()
|
||||
|
||||
|
||||
async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
|
||||
store = getattr(request.app.state, "channel_runtime_config_store", None)
|
||||
if isinstance(store, ChannelRuntimeConfigStore):
|
||||
return store
|
||||
# Constructing the store reads its JSON file from disk; keep it off the
|
||||
# event loop.
|
||||
store = await asyncio.to_thread(ChannelRuntimeConfigStore)
|
||||
request.app.state.channel_runtime_config_store = store
|
||||
return store
|
||||
|
||||
|
||||
async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||
config = getattr(request.app.state, "channel_connections_config", None)
|
||||
if not isinstance(config, ChannelConnectionsConfig):
|
||||
config = _get_app_config().channel_connections
|
||||
config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request))
|
||||
request.app.state.channel_connections_config = config
|
||||
return config
|
||||
|
||||
|
||||
async def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||
state_config = getattr(request.app.state, "channels_config", None)
|
||||
if isinstance(state_config, dict):
|
||||
return state_config
|
||||
|
||||
result = await _load_channels_config(request, await _get_channel_connections_config(request))
|
||||
request.app.state.channels_config = result
|
||||
return result
|
||||
|
||||
|
||||
async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
|
||||
app_config = _get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = extra.get("channels")
|
||||
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||
merge_runtime_channel_configs(
|
||||
result,
|
||||
config,
|
||||
store=await _get_runtime_config_store(request),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||
repo = getattr(request.app.state, "channel_connection_repo", None)
|
||||
if isinstance(repo, ChannelConnectionRepository):
|
||||
return repo
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||
|
||||
repo = ChannelConnectionRepository(sf)
|
||||
request.app.state.channel_connection_repo = repo
|
||||
return repo
|
||||
|
||||
|
||||
def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||
provider_config = getattr(config, provider, None)
|
||||
if provider_config is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return provider_config
|
||||
|
||||
|
||||
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||
return False
|
||||
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
|
||||
|
||||
|
||||
def _runtime_unavailable_reason(provider: str) -> str:
|
||||
meta = _PROVIDER_META.get(provider)
|
||||
display_name = meta["display_name"] if meta else provider
|
||||
return f"Enter the required {display_name} credentials to connect this channel."
|
||||
|
||||
|
||||
def _runtime_not_running_reason(provider: str) -> str:
|
||||
meta = _PROVIDER_META.get(provider)
|
||||
display_name = meta["display_name"] if meta else provider
|
||||
return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
|
||||
|
||||
|
||||
def _runtime_channel_running(provider: str) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.debug("Unable to inspect channel service status", exc_info=True)
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
try:
|
||||
status = service.get_status()
|
||||
except Exception:
|
||||
logger.debug("Unable to read channel service status", exc_info=True)
|
||||
return None
|
||||
|
||||
if not status.get("service_running"):
|
||||
return False
|
||||
channel_status = status.get("channels", {}).get(provider)
|
||||
if not isinstance(channel_status, dict):
|
||||
return None
|
||||
return bool(channel_status.get("running"))
|
||||
|
||||
|
||||
async def _ensure_runtime_channel_ready_if_available(
|
||||
provider: str,
|
||||
channels_config: dict[str, Any],
|
||||
) -> bool | None:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
|
||||
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
|
||||
if ensure_channel_ready is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await ensure_channel_ready(provider, runtime_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to reconcile runtime channel readiness")
|
||||
return False
|
||||
|
||||
|
||||
def _provider_unavailable_reason(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
) -> str | None:
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
return None
|
||||
if not provider_config.configured:
|
||||
return _runtime_unavailable_reason(provider)
|
||||
if not _runtime_channel_configured(provider, channels_config):
|
||||
return _runtime_unavailable_reason(provider)
|
||||
if _runtime_channel_running(provider) is False:
|
||||
return _runtime_not_running_reason(provider)
|
||||
return None
|
||||
|
||||
|
||||
def _provider_status(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
) -> tuple[dict[str, bool], str | None]:
|
||||
declared = config.provider_status(provider)
|
||||
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
|
||||
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
|
||||
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
|
||||
|
||||
|
||||
def _new_binding_code() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
async def _create_state(
|
||||
repo: ChannelConnectionRepository,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
) -> str:
|
||||
state = _new_binding_code()
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _connect_instruction(provider: str, code: str) -> str:
|
||||
if provider == "telegram":
|
||||
return f"Send /start {code} to the DeerFlow Telegram bot."
|
||||
meta = _PROVIDER_META.get(provider)
|
||||
if meta is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
|
||||
|
||||
|
||||
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
|
||||
if provider == "telegram":
|
||||
provider_config = _provider_config(config, provider)
|
||||
return f"https://t.me/{provider_config.bot_username}?start={code}"
|
||||
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
|
||||
return None
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
|
||||
|
||||
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
||||
value = connection.get("updated_at")
|
||||
if isinstance(value, datetime):
|
||||
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
||||
if isinstance(value, str) and value:
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
pass
|
||||
return datetime.min.replace(tzinfo=UTC)
|
||||
|
||||
|
||||
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
by_provider: dict[str, dict[str, Any]] = {}
|
||||
for item in connections:
|
||||
existing = by_provider.get(item["provider"])
|
||||
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
|
||||
by_provider[item["provider"]] = item
|
||||
return by_provider
|
||||
|
||||
|
||||
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
|
||||
fields = _CREDENTIAL_FIELDS.get(provider)
|
||||
if fields is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return [ChannelCredentialFieldResponse(**field) for field in fields]
|
||||
|
||||
|
||||
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict):
|
||||
return {}
|
||||
|
||||
values: dict[str, str] = {}
|
||||
for field in _credential_fields(provider):
|
||||
value = str(runtime_config.get(field.name) or "").strip()
|
||||
if not value:
|
||||
continue
|
||||
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
|
||||
return values
|
||||
|
||||
|
||||
def _provider_response(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
meta: dict[str, str],
|
||||
connection: dict[str, Any] | None = None,
|
||||
) -> ChannelProviderResponse:
|
||||
from app.gateway.auth_disabled import is_auth_disabled
|
||||
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
if connection:
|
||||
connection_status = connection["status"]
|
||||
elif is_auth_disabled() and status["configured"] and unavailable_reason is None:
|
||||
# Auth-disabled local mode routes every channel message to the default
|
||||
# user, so a configured running channel needs no per-user binding.
|
||||
connection_status = "connected"
|
||||
else:
|
||||
connection_status = "not_connected"
|
||||
credential_values = _credential_values(provider, channels_config)
|
||||
if provider == "telegram" and not credential_values.get("bot_username"):
|
||||
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
|
||||
if bot_username:
|
||||
credential_values["bot_username"] = bot_username
|
||||
return ChannelProviderResponse(
|
||||
provider=provider,
|
||||
display_name=meta["display_name"],
|
||||
enabled=status["enabled"],
|
||||
configured=status["configured"],
|
||||
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
|
||||
unavailable_reason=unavailable_reason,
|
||||
auth_mode=meta["auth_mode"],
|
||||
connection_status=connection_status,
|
||||
credential_fields=_credential_fields(provider),
|
||||
credential_values=credential_values,
|
||||
)
|
||||
|
||||
|
||||
def _required_runtime_values(
|
||||
provider: str,
|
||||
values: dict[str, str],
|
||||
existing_config: dict[str, Any] | None = None,
|
||||
) -> dict[str, str]:
|
||||
fields = _credential_fields(provider)
|
||||
cleaned: dict[str, str] = {}
|
||||
missing: list[str] = []
|
||||
existing_config = existing_config or {}
|
||||
for field in fields:
|
||||
raw_value = values.get(field.name, "")
|
||||
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
|
||||
existing_value = str(existing_config.get(field.name) or "").strip()
|
||||
if existing_value:
|
||||
cleaned[field.name] = existing_value
|
||||
continue
|
||||
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
|
||||
if field.required and not value:
|
||||
missing.append(field.label)
|
||||
cleaned[field.name] = value
|
||||
if missing:
|
||||
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
|
||||
return cleaned
|
||||
|
||||
|
||||
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel service while configuring a runtime channel")
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
return await service.configure_channel(provider, runtime_config)
|
||||
|
||||
|
||||
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel service while disconnecting a runtime channel")
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
|
||||
runtime_config = channels_config.get(provider)
|
||||
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
|
||||
return await service.configure_channel(provider, runtime_config)
|
||||
return await service.remove_channel(provider)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = await _get_channel_connections_config(request)
|
||||
channels_config = await _get_channels_config(request)
|
||||
repo = None
|
||||
if config.enabled:
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
owner_user_id = _get_user_id(request)
|
||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||
by_provider = _newest_connection_by_provider(connections)
|
||||
|
||||
enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]]
|
||||
# Readiness reconciliation is independent per provider; run it
|
||||
# concurrently so one slow channel restart does not serialize the
|
||||
# whole /providers response.
|
||||
await asyncio.gather(
|
||||
*(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)),
|
||||
)
|
||||
|
||||
providers: list[ChannelProviderResponse] = []
|
||||
for provider in enabled_providers:
|
||||
connection = by_provider.get(provider)
|
||||
providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection))
|
||||
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||
|
||||
|
||||
@router.get("/connections", response_model=ChannelConnectionsResponse)
|
||||
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
return ChannelConnectionsResponse(connections=[])
|
||||
repo = _get_repository(request, config)
|
||||
rows = await repo.list_connections(_get_user_id(request))
|
||||
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
|
||||
|
||||
|
||||
@router.delete("/connections/{connection_id}", status_code=204)
|
||||
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection_id,
|
||||
owner_user_id=_get_user_id(request),
|
||||
)
|
||||
if not disconnected:
|
||||
raise HTTPException(status_code=404, detail="Channel connection not found")
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
|
||||
await _require_admin_user(request)
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
|
||||
owner_user_id = _get_user_id(request)
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
repo = None
|
||||
|
||||
if repo is not None:
|
||||
for connection in await repo.list_connections(owner_user_id):
|
||||
if connection["provider"] == provider and connection["status"] != "revoked":
|
||||
await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id=owner_user_id,
|
||||
)
|
||||
|
||||
store = await _get_runtime_config_store(request)
|
||||
await asyncio.to_thread(store.set_provider_disconnected, provider)
|
||||
channels_config = await _load_channels_config(request, config)
|
||||
request.app.state.channels_config = channels_config
|
||||
|
||||
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
|
||||
if stopped is False:
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
|
||||
|
||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||
|
||||
|
||||
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||
config = await _get_channel_connections_config(request)
|
||||
channels_config = await _get_channels_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
|
||||
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
if not status["enabled"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
if unavailable_reason:
|
||||
raise HTTPException(status_code=400, detail=unavailable_reason)
|
||||
if not status["configured"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
code = await _create_state(
|
||||
repo,
|
||||
owner_user_id=_get_user_id(request),
|
||||
provider=provider,
|
||||
)
|
||||
return ChannelConnectResponse(
|
||||
provider=provider,
|
||||
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||
url=_connect_url(config, provider, code),
|
||||
code=code,
|
||||
instruction=_connect_instruction(provider, code),
|
||||
expires_in=_STATE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||
async def configure_channel_provider_runtime(
|
||||
provider: str,
|
||||
body: ChannelRuntimeConfigRequest,
|
||||
request: Request,
|
||||
) -> ChannelProviderResponse:
|
||||
await _require_admin_user(request)
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
|
||||
channels_config = await _get_channels_config(request)
|
||||
existing = channels_config.get(provider)
|
||||
runtime_config = dict(existing) if isinstance(existing, dict) else {}
|
||||
values = _required_runtime_values(provider, body.values, runtime_config)
|
||||
runtime_config["enabled"] = True
|
||||
|
||||
for key in _RUNTIME_REQUIREMENTS[provider]:
|
||||
runtime_config[key] = values[key]
|
||||
|
||||
if provider == "telegram":
|
||||
# The deep-link username is persisted with the runtime channel config
|
||||
# (set_provider_config below) and applied to future requests via
|
||||
# apply_runtime_connection_config; never mutate the config instance
|
||||
# cached by get_app_config().
|
||||
runtime_config["bot_username"] = values["bot_username"]
|
||||
|
||||
channels_config[provider] = runtime_config
|
||||
request.app.state.channels_config = channels_config
|
||||
|
||||
started = await _restart_runtime_channel_if_available(provider, runtime_config)
|
||||
if started is False:
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
||||
|
||||
store = await _get_runtime_config_store(request)
|
||||
await asyncio.to_thread(store.set_provider_config, provider, runtime_config)
|
||||
|
||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
|
||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
|
||||
|
||||
|
||||
class MemoryStatusResponse(BaseModel):
|
||||
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
"max_facts": 100,
|
||||
"fact_confidence_threshold": 0.7,
|
||||
"injection_enabled": true,
|
||||
"max_injection_tokens": 2000
|
||||
"max_injection_tokens": 2000,
|
||||
"token_counting": "tiktoken"
|
||||
}
|
||||
```
|
||||
"""
|
||||
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
token_counting=config.token_counting,
|
||||
)
|
||||
|
||||
|
||||
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
token_counting=config.token_counting,
|
||||
),
|
||||
data=MemoryResponse(**memory_data),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_
|
||||
from app.gateway.pagination import trim_run_message_page
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
from deerflow.runtime import serialize_channel_values_for_api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
||||
@@ -82,7 +82,7 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
return serialize_channel_values_for_api(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.pagination import trim_run_message_page
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values_for_api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
||||
@@ -192,7 +192,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
return serialize_channel_values_for_api(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
|
||||
@@ -22,9 +22,10 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer
|
||||
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
from deerflow.runtime import serialize_channel_values_for_api
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.utils.time import coerce_iso, now_iso
|
||||
|
||||
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = now_iso()
|
||||
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_store.get(thread_id)
|
||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||
if existing_record is None and thread_owner_user_id:
|
||||
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
||||
if unscoped_record is not None:
|
||||
if unscoped_record.get("user_id") != thread_owner_user_id:
|
||||
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
await thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
**thread_owner_kwargs,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
except Exception:
|
||||
@@ -427,7 +437,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
||||
created_at=coerce_iso(record.get("created_at", "")),
|
||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||
metadata=record.get("metadata", {}),
|
||||
values=serialize_channel_values(channel_values),
|
||||
values=serialize_channel_values_for_api(channel_values),
|
||||
)
|
||||
|
||||
|
||||
@@ -470,7 +480,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||
|
||||
values = serialize_channel_values(channel_values)
|
||||
values = serialize_channel_values_for_api(channel_values)
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=values,
|
||||
@@ -578,7 +588,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
return ThreadStateResponse(
|
||||
values=serialize_channel_values(channel_values),
|
||||
values=serialize_channel_values_for_api(channel_values),
|
||||
next=[],
|
||||
metadata=metadata,
|
||||
checkpoint_id=new_checkpoint_id,
|
||||
@@ -630,7 +640,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
if is_latest_checkpoint:
|
||||
messages = channel_values.get("messages")
|
||||
if messages:
|
||||
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||
values["messages"] = serialize_channel_values_for_api({"messages": messages}).get("messages", [])
|
||||
is_latest_checkpoint = False
|
||||
|
||||
# Derive next tasks
|
||||
|
||||
+115
-70
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_to_messages
|
||||
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime import (
|
||||
@@ -35,6 +36,7 @@ from deerflow.runtime import (
|
||||
run_agent,
|
||||
)
|
||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -209,11 +211,14 @@ def build_run_config(
|
||||
|
||||
When *assistant_id* refers to a custom agent (anything other than
|
||||
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
|
||||
whichever runtime options container is active: ``context`` for
|
||||
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
|
||||
``make_lead_agent`` reads this key to load the matching
|
||||
``agents/<name>/SOUL.md`` and per-agent config — without it the agent
|
||||
silently runs as the default lead agent.
|
||||
both ``configurable`` and ``context`` so it is visible to legacy
|
||||
configurable readers and to LangGraph ``ToolRuntime.context`` consumers
|
||||
(e.g. the ``setup_agent`` tool, which since LangGraph >=1.1.9 no longer
|
||||
falls back from ``context`` to ``configurable``). An explicit
|
||||
``agent_name`` in either container takes precedence over the value
|
||||
derived from ``assistant_id``. ``make_lead_agent`` reads this key to
|
||||
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
||||
without it the agent silently runs as the default lead agent.
|
||||
|
||||
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||
@@ -251,19 +256,23 @@ def build_run_config(
|
||||
config["configurable"] = {"thread_id": thread_id}
|
||||
|
||||
# Inject custom agent name when the caller specified a non-default assistant.
|
||||
# Honour an explicit agent_name in the active runtime options container.
|
||||
# Honour an explicit agent_name in either runtime options container.
|
||||
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
|
||||
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||
if "configurable" in config:
|
||||
target = config["configurable"]
|
||||
elif "context" in config:
|
||||
target = config["context"]
|
||||
else:
|
||||
target = config.setdefault("configurable", {})
|
||||
if target is not None and "agent_name" not in target:
|
||||
target["agent_name"] = normalized
|
||||
configurable = config.setdefault("configurable", {})
|
||||
runtime_context = config.setdefault("context", {})
|
||||
explicit_agent_name: str | None = None
|
||||
if isinstance(configurable, dict) and isinstance(configurable.get("agent_name"), str):
|
||||
explicit_agent_name = configurable["agent_name"]
|
||||
elif isinstance(runtime_context, dict) and isinstance(runtime_context.get("agent_name"), str):
|
||||
explicit_agent_name = runtime_context["agent_name"]
|
||||
effective_agent_name = explicit_agent_name or normalized
|
||||
if isinstance(configurable, dict):
|
||||
configurable["agent_name"] = effective_agent_name
|
||||
if isinstance(runtime_context, dict):
|
||||
runtime_context["agent_name"] = effective_agent_name
|
||||
config.setdefault("run_name", resolve_root_run_name(config, normalized))
|
||||
if metadata:
|
||||
config.setdefault("metadata", {}).update(metadata)
|
||||
@@ -315,72 +324,108 @@ async def start_run(
|
||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||
)
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
model_name=model_name,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||
# Stateless run endpoints carry thread_id in the request *body*, so the
|
||||
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
||||
# from the path param -- cannot protect them. Enforce thread ownership here,
|
||||
# before any run is created, so one user cannot start runs on (or read /wait
|
||||
# checkpoint state from) another user's thread. Missing rows (auto-created
|
||||
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
||||
# via check_access; only a thread already owned by another user is rejected
|
||||
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
||||
# channel runs act on behalf of the connection owner carried in
|
||||
# X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
|
||||
# bypassing the check -- a leaked internal token must not grant cross-user
|
||||
# thread access.
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is not None:
|
||||
allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
|
||||
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||
# Channel workers may also act for the connection owner named in
|
||||
# the trusted header (e.g. claiming a legacy default-owned channel
|
||||
# thread for its real owner).
|
||||
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
model_name=model_name,
|
||||
user_id=owner_user_id,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None and owner_user_id:
|
||||
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
||||
if unscoped_existing is not None:
|
||||
if unscoped_existing.get("user_id") != owner_user_id:
|
||||
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
)
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
return record
|
||||
finally:
|
||||
if owner_context_token is not None:
|
||||
reset_current_user(owner_context_token)
|
||||
|
||||
|
||||
async def sse_consumer(
|
||||
|
||||
@@ -67,6 +67,11 @@ The normal workflow is:
|
||||
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
|
||||
4. Let CI prevent that path from regressing.
|
||||
|
||||
Contributors changing backend async code can run the `blocking-io-guard` skill
|
||||
(`.agent/skills/blocking-io-guard/`) to execute steps 1–3 for their own diff: it
|
||||
scans the change for blocking-IO candidates, drafts or extends a runtime anchor,
|
||||
and verifies the anchor fails when the blocking IO regresses.
|
||||
|
||||
Runtime detection has two maintenance paths.
|
||||
|
||||
### Add a runtime rule
|
||||
|
||||
@@ -234,7 +234,7 @@ tools:
|
||||
```
|
||||
|
||||
**Built-in Tools**:
|
||||
- `web_search` - Search the web (DuckDuckGo, Tavily, Exa, InfoQuest, Firecrawl)
|
||||
- `web_search` - Search the web (DuckDuckGo, Tavily, Brave, Exa, InfoQuest, Firecrawl)
|
||||
- `web_fetch` - Fetch web pages (Jina AI, Exa, InfoQuest, Firecrawl)
|
||||
- `ls` - List directory contents
|
||||
- `read_file` - Read file contents
|
||||
@@ -303,6 +303,55 @@ When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` val
|
||||
|
||||
For bare-metal Docker sandbox runs that use localhost, DeerFlow binds the sandbox HTTP port to `127.0.0.1` by default so it is not exposed on every host interface. Docker-outside-of-Docker deployments that connect through `host.docker.internal` keep the broad legacy bind for compatibility. Set `DEER_FLOW_SANDBOX_BIND_HOST` explicitly if your deployment needs a different bind address.
|
||||
|
||||
### Building a Custom AIO Sandbox Image
|
||||
|
||||
`AioSandboxProvider` talks to the sandbox container through the `agent-sandbox` SDK. The Dockerfile for the default `enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest` image is not part of this repository; DeerFlow treats that image as an upstream AIO sandbox runtime.
|
||||
|
||||
For persistent system or language dependencies, extend the published image and keep its startup command intact:
|
||||
|
||||
```dockerfile
|
||||
FROM enterprise-public-cn-beijing.cr.volces.com/vefaas-public/all-in-one-sandbox:latest
|
||||
|
||||
USER root
|
||||
# Example user dependency; not required by DeerFlow itself.
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends graphviz \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Example Python dependency for work done inside the sandbox.
|
||||
RUN python -m pip install --no-cache-dir pandas
|
||||
|
||||
# Do not override ENTRYPOINT or CMD; keep the upstream sandbox server startup.
|
||||
```
|
||||
|
||||
Use the custom image in local Docker or Apple Container mode with `sandbox.image`:
|
||||
|
||||
```yaml
|
||||
sandbox:
|
||||
use: deerflow.community.aio_sandbox:AioSandboxProvider
|
||||
image: your-registry/your-aio-sandbox:tag
|
||||
```
|
||||
|
||||
In provisioner mode, sandbox Pods are created by the provisioner service, so configure the provisioner `SANDBOX_IMAGE` environment variable instead of `sandbox.image`. See the [Provisioner Setup Guide](../../docker/provisioner/README.md#custom-sandbox-image).
|
||||
|
||||
If you rebuild the runtime from scratch instead of extending the published image, it must expose the same HTTP API used by `agent-sandbox`. DeerFlow currently depends on:
|
||||
|
||||
- `sandbox.get_context()`, including `home_dir`
|
||||
- `shell.exec_command(...)`
|
||||
- `file.read_file(...)`
|
||||
- `file.write_file(...)`, including base64 writes for binary content
|
||||
- streamed `file.download_file(...)`
|
||||
- `file.find_files(...)`
|
||||
- `file.list_path(...)`
|
||||
- `file.search_in_file(...)`
|
||||
|
||||
Custom images must also keep these compatibility constraints:
|
||||
|
||||
- The container should listen on the configured sandbox port, `8080` by default.
|
||||
- `/mnt/user-data` must remain writable because DeerFlow mounts thread workspace, uploads, and outputs there.
|
||||
- `home_dir` comes from the sandbox context endpoint; do not assume DeerFlow hardcodes it.
|
||||
- Shell command handling must remain compatible with serialized `exec_command` calls. DeerFlow serializes shell access on the host side to avoid corrupting the sandbox's persistent shell session.
|
||||
|
||||
### Skills
|
||||
|
||||
Configure the skills directory for specialized workflows:
|
||||
@@ -364,6 +413,7 @@ models:
|
||||
- `MIMO_API_KEY` - Xiaomi MiMo API key
|
||||
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
|
||||
- `TAVILY_API_KEY` - Tavily search API key
|
||||
- `BRAVE_SEARCH_API_KEY` - Brave Search API key
|
||||
- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
|
||||
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
|
||||
- `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path
|
||||
@@ -384,6 +434,76 @@ DeerFlow searches for configuration in this order:
|
||||
3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or under the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
|
||||
4. Legacy backend/repository-root locations for monorepo compatibility
|
||||
|
||||
## Security Notes
|
||||
### Sandbox Isolation and the Docker Socket (DooD)
|
||||
|
||||
DeerFlow executes agent-generated shell/code through a configurable sandbox
|
||||
(`sandbox.use` in `config.yaml`). The isolation guarantees differ by mode, and
|
||||
one mode requires mounting the host Docker socket. Understand the trade-offs
|
||||
before exposing an instance to untrusted input.
|
||||
|
||||
| Mode | `config.yaml` | Host Docker socket | Isolation |
|
||||
|------|---------------|--------------------|-----------|
|
||||
| `local` (default) | `deerflow.sandbox.local:LocalSandboxProvider` | Not mounted | Commands run **inside the gateway container** on its filesystem. Not a strong boundary — `allow_host_bash` is `false` by default and should stay off for untrusted workloads. |
|
||||
| `aio` (pure DooD) | `deerflow.community.aio_sandbox:AioSandboxProvider` (no `provisioner_url`) | **Mounted** (opt-in overlay) | Sandbox containers are started via the host Docker daemon. |
|
||||
| `provisioner` (Kubernetes) | `AioSandboxProvider` + `provisioner_url` | Not mounted | Sandbox pods are created through the provisioner's K8s API over HTTP. Strongest isolation. |
|
||||
|
||||
#### The Docker socket is host root
|
||||
|
||||
Mounting `/var/run/docker.sock` into a container grants that container
|
||||
**root-equivalent control of the host**: anything able to reach the socket can
|
||||
start a new container that bind-mounts the host filesystem and escape. This
|
||||
matters for DeerFlow because the gateway executes model-generated commands, so a
|
||||
prompt injection or any in-container code-execution primitive could pivot to the
|
||||
host through the socket.
|
||||
|
||||
To keep this off the default attack surface:
|
||||
|
||||
- The host Docker socket is **not** mounted by the default Compose stack. It is
|
||||
added only for `aio` mode through the opt-in `docker/docker-compose.dood.yaml`
|
||||
overlay, which `scripts/deploy.sh` and `scripts/docker.sh` append
|
||||
automatically when `detect_sandbox_mode()` returns `aio`.
|
||||
- Prefer **provisioner/Kubernetes mode** for multi-tenant or internet-exposed
|
||||
deployments — it isolates sandboxes without handing the gateway the host
|
||||
daemon.
|
||||
- If you must use `aio`/DooD, treat the host as part of the gateway's trust
|
||||
boundary: run it on a dedicated host, and consider a scoped Docker API proxy
|
||||
instead of the raw socket.
|
||||
|
||||
> Note: the gateway bind-mounts `$HOME/.claude` and `$HOME/.codex` (read-only)
|
||||
> for CLI auto-auth in **all** modes. These hold long-lived CLI credentials;
|
||||
> scope or omit them when the gateway runs untrusted workloads.
|
||||
|
||||
### CLI Credential Mounts (Claude Code / Codex)
|
||||
|
||||
DeerFlow can reuse your Claude Code / Codex CLI subscription login as a model
|
||||
provider (`ClaudeChatModel`, the Codex provider) or for ACP agents that run the
|
||||
CLI in-container. The Compose stack used to bind-mount the **entire** `~/.claude`
|
||||
and `~/.codex` directories (read-only) into the gateway container in **every**
|
||||
configuration — exposing not just credentials but full conversation history,
|
||||
per-project session data, and global CLI config. A gateway compromise (prompt
|
||||
injection, tool/MCP misuse, RCE) would leak all of it.
|
||||
|
||||
These directories are **no longer mounted by default**. Supply CLI credentials
|
||||
with the least exposure that fits your setup:
|
||||
|
||||
| Need | How | Exposure |
|
||||
|------|-----|----------|
|
||||
| Claude model provider | env `CLAUDE_CODE_OAUTH_TOKEN` / `ANTHROPIC_AUTH_TOKEN` (via `.env`), or `CLAUDE_CODE_CREDENTIALS_PATH` → a single mounted `.credentials.json` | none / one file |
|
||||
| Codex model provider | env `CODEX_AUTH_PATH` pointing at a single mounted `auth.json` | one file |
|
||||
| ACP agent | the adapter's own auth — many ACP adapters take an env API key (e.g. `ANTHROPIC_API_KEY` / `OPENAI_API_KEY`) and need no mount; use the opt-in `docker/docker-compose.cli-auth.yaml` overlay only if your adapter reads the full CLI config dir | none / full dir |
|
||||
|
||||
The Gateway credential loader checks environment variables **before** the
|
||||
default credential files, so the env-token paths need no bind mount at all. ACP
|
||||
adapters authenticate independently of DeerFlow via their own documented env —
|
||||
for example the common `claude-code-acp` adapter starts as
|
||||
`ANTHROPIC_API_KEY=… claude-code-acp` and honors `CLAUDE_CONFIG_DIR` to redirect
|
||||
its config directory, so it needs no `~/.claude` mount at all. Prefer the
|
||||
adapter's documented env auth, and reach for the
|
||||
`docker-compose.cli-auth.yaml` overlay only as a fallback for an adapter that
|
||||
genuinely reads the full CLI config directory.
|
||||
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Place `config.yaml` in project root** - Set `DEER_FLOW_PROJECT_ROOT` if the runtime starts elsewhere
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
# IM Channel Connections
|
||||
|
||||
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
|
||||
|
||||
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
|
||||
|
||||
## Configuration
|
||||
|
||||
Configure the actual IM bots under the existing `channels` block:
|
||||
|
||||
```yaml
|
||||
channels:
|
||||
telegram:
|
||||
enabled: true
|
||||
bot_token: $TELEGRAM_BOT_TOKEN
|
||||
|
||||
slack:
|
||||
enabled: true
|
||||
bot_token: $SLACK_BOT_TOKEN
|
||||
app_token: $SLACK_APP_TOKEN
|
||||
|
||||
discord:
|
||||
enabled: true
|
||||
bot_token: $DISCORD_BOT_TOKEN
|
||||
|
||||
feishu:
|
||||
enabled: true
|
||||
app_id: $FEISHU_APP_ID
|
||||
app_secret: $FEISHU_APP_SECRET
|
||||
|
||||
dingtalk:
|
||||
enabled: true
|
||||
client_id: $DINGTALK_CLIENT_ID
|
||||
client_secret: $DINGTALK_CLIENT_SECRET
|
||||
|
||||
wechat:
|
||||
enabled: true
|
||||
bot_token: $WECHAT_BOT_TOKEN
|
||||
|
||||
wecom:
|
||||
enabled: true
|
||||
bot_id: $WECOM_BOT_ID
|
||||
bot_secret: $WECOM_BOT_SECRET
|
||||
```
|
||||
|
||||
Then enable user bindings in `channel_connections`:
|
||||
|
||||
```yaml
|
||||
channel_connections:
|
||||
enabled: true
|
||||
|
||||
telegram:
|
||||
enabled: true
|
||||
bot_username: $TELEGRAM_BOT_USERNAME
|
||||
|
||||
slack:
|
||||
enabled: true
|
||||
|
||||
discord:
|
||||
enabled: true
|
||||
|
||||
feishu:
|
||||
enabled: true
|
||||
|
||||
dingtalk:
|
||||
enabled: true
|
||||
|
||||
wechat:
|
||||
enabled: true
|
||||
|
||||
wecom:
|
||||
enabled: true
|
||||
```
|
||||
|
||||
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
|
||||
|
||||
## Connect Flow
|
||||
|
||||
Telegram:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
|
||||
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
|
||||
|
||||
Slack:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
|
||||
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
|
||||
|
||||
Discord:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
|
||||
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
|
||||
|
||||
Feishu/Lark, DingTalk, WeChat, and WeCom:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The UI shows `Send /connect <code> to the DeerFlow <Provider> bot.`
|
||||
- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user.
|
||||
|
||||
Codes use 128 bits of randomness, expire after 10 minutes, and are single-use.
|
||||
|
||||
## Runtime Model
|
||||
|
||||
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
|
||||
|
||||
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
|
||||
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
|
||||
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
|
||||
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
|
||||
|
||||
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
|
||||
|
||||
## Security Notes
|
||||
|
||||
- Browser APIs remain authenticated and CSRF-protected.
|
||||
- Connect codes are 128-bit random, short-lived, and single-use.
|
||||
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
|
||||
- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets.
|
||||
- This implementation does not add public provider callback or webhook routes.
|
||||
@@ -31,7 +31,8 @@ Current injection format:
|
||||
|
||||
Token counting:
|
||||
- Uses `tiktoken` (`cl100k_base`) when available
|
||||
- Falls back to `len(text) // 4` if tokenizer import fails
|
||||
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
|
||||
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
|
||||
|
||||
## Known Gap
|
||||
|
||||
|
||||
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
memory_content = format_memory_for_injection(
|
||||
memory_data,
|
||||
max_tokens=config.max_injection_tokens,
|
||||
use_tiktoken=(config.token_counting == "tiktoken"),
|
||||
)
|
||||
|
||||
if not memory_content.strip():
|
||||
return ""
|
||||
|
||||
@@ -5,7 +5,9 @@ from __future__ import annotations
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from typing import Any
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
|
||||
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||
# (potentially slow) first ``get_encoding`` call.
|
||||
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
|
||||
#
|
||||
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
||||
# a network-restricted environment does not re-attempt the blocking BPE
|
||||
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
||||
# failure is allowed to expire so a transient network outage can self-heal back
|
||||
# to accurate tiktoken counting without a process restart. A load already in
|
||||
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
||||
# fall back immediately instead of spawning more blocking
|
||||
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
||||
# config to skip tiktoken entirely.
|
||||
_TIKTOKEN_ENCODING_MISSING = object()
|
||||
_TIKTOKEN_ENCODING_LOADING = object()
|
||||
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
||||
# tuning constant rather than a user-facing config: it only affects how quickly
|
||||
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
||||
# Deployments that want to avoid tiktoken's network dependency entirely should
|
||||
# set ``memory.token_counting: char`` instead of tuning this value.
|
||||
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
||||
_tiktoken_encoding_cache: dict[str, Any] = {}
|
||||
_tiktoken_encoding_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
|
||||
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||
The caller must therefore be prepared for this to block and should run it
|
||||
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||
|
||||
A failed load is remembered (with a timestamp) so subsequent calls fall
|
||||
back immediately to character-based estimation instead of re-triggering the
|
||||
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
||||
so a transient outage can self-heal without a restart. A load already in
|
||||
progress is also remembered so that a timed-out caller does not leave a
|
||||
window where later requests start more blocking ``get_encoding`` calls.
|
||||
"""
|
||||
if not TIKTOKEN_AVAILABLE:
|
||||
return None
|
||||
|
||||
cached = _tiktoken_encoding_cache.get(encoding_name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
||||
if cached is _TIKTOKEN_ENCODING_LOADING:
|
||||
return None
|
||||
if isinstance(cached, tuple):
|
||||
# Cached failure: (None, failed_at). Retry only after cooldown.
|
||||
_, failed_at = cached
|
||||
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
||||
return None
|
||||
cached = _TIKTOKEN_ENCODING_MISSING
|
||||
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
||||
return cast("tiktoken.Encoding", cached)
|
||||
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
||||
|
||||
try:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||
return encoding
|
||||
except Exception:
|
||||
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
||||
return None
|
||||
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||
return encoding
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||
|
||||
def _char_based_token_estimate(text: str) -> int:
|
||||
"""Network-free token estimate that accounts for CJK density.
|
||||
|
||||
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
||||
(~4 chars per token) but significantly under-estimates token counts for
|
||||
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
||||
characters per token. Counting CJK characters separately (~2 chars per
|
||||
token) avoids over-filling the injection budget for CJK-heavy memory
|
||||
content.
|
||||
"""
|
||||
cjk = sum(
|
||||
1
|
||||
for ch in text
|
||||
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
||||
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
||||
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
||||
)
|
||||
return (len(text) - cjk) // 4 + cjk // 2
|
||||
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
||||
"""Count tokens in text using tiktoken.
|
||||
|
||||
Args:
|
||||
text: The text to count tokens for.
|
||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
||||
network-free character-based estimate. This guarantees no BPE
|
||||
download is attempted (see ``memory.token_counting`` config).
|
||||
|
||||
Returns:
|
||||
The number of tokens in the text.
|
||||
"""
|
||||
if not use_tiktoken:
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
encoding = _get_tiktoken_encoding(encoding_name)
|
||||
if encoding is None:
|
||||
# Fallback to character-based estimation if tiktoken is not available
|
||||
# or the encoding failed to load.
|
||||
return len(text) // 4
|
||||
# Fallback to CJK-aware character estimation if tiktoken is not
|
||||
# available or the encoding failed to load.
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
try:
|
||||
return len(encoding.encode(text))
|
||||
except Exception:
|
||||
# Fallback to character-based estimation on error
|
||||
return len(text) // 4
|
||||
# Fallback to CJK-aware character estimation on error.
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
|
||||
def warm_tiktoken_cache() -> bool:
|
||||
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
||||
"""Format memory data for injection into system prompt.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data dictionary.
|
||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||
use_tiktoken: When ``False``, all token counting uses the network-free
|
||||
character-based estimate instead of tiktoken (see
|
||||
``memory.token_counting`` config). Defaults to ``True``.
|
||||
|
||||
Returns:
|
||||
Formatted memory string for system prompt injection.
|
||||
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
# Compute token count for existing sections once, then account
|
||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||
base_text = "\n\n".join(sections)
|
||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
||||
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
||||
# Account for the separator between existing sections and the facts section.
|
||||
facts_header = "Facts:\n"
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
||||
running_tokens = base_tokens + separator_tokens
|
||||
|
||||
fact_lines: list[str] = []
|
||||
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
line_tokens = _count_tokens(line_text)
|
||||
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
||||
|
||||
if running_tokens + line_tokens <= max_tokens:
|
||||
fact_lines.append(line)
|
||||
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
|
||||
result = "\n\n".join(sections)
|
||||
|
||||
# Use accurate token counting with tiktoken
|
||||
token_count = _count_tokens(result)
|
||||
# Use accurate token counting with tiktoken (or the char-based estimate
|
||||
# when use_tiktoken is False).
|
||||
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
||||
if token_count > max_tokens:
|
||||
# Truncate to fit within token limit
|
||||
# Estimate characters to remove based on token ratio
|
||||
|
||||
@@ -46,11 +46,6 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _completion_reminder_count(messages: list[Any]) -> int:
|
||||
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
|
||||
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
|
||||
|
||||
|
||||
def _format_todos(todos: list[Todo]) -> str:
|
||||
"""Format a list of Todo items into a human-readable string."""
|
||||
lines: list[str] = []
|
||||
|
||||
@@ -18,6 +18,27 @@ class ViewedImageData(TypedDict):
|
||||
mime_type: str
|
||||
|
||||
|
||||
def merge_sandbox(existing: SandboxState | None, new: SandboxState | None) -> SandboxState | None:
|
||||
"""Reducer for sandbox state - accepts idempotent writes only.
|
||||
|
||||
Multiple sandbox tools can initialize lazily in the same graph step and
|
||||
emit the same sandbox_id via Command(update=...). LangGraph needs an
|
||||
explicit reducer for that shared state key. Different sandbox ids in the
|
||||
same thread indicate a lifecycle/isolation bug, so fail closed instead of
|
||||
choosing one silently.
|
||||
"""
|
||||
if new is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return new
|
||||
|
||||
existing_id = existing.get("sandbox_id")
|
||||
new_id = new.get("sandbox_id")
|
||||
if existing_id == new_id:
|
||||
return existing
|
||||
raise ValueError(f"Conflicting sandbox state updates: {existing_id!r} != {new_id!r}")
|
||||
|
||||
|
||||
def merge_artifacts(existing: list[str] | None, new: list[str] | None) -> list[str]:
|
||||
"""Reducer for artifacts list - merges and deduplicates artifacts."""
|
||||
if existing is None:
|
||||
@@ -85,7 +106,7 @@ def merge_promoted(existing: PromotedTools | None, new: PromotedTools | None) ->
|
||||
|
||||
|
||||
class ThreadState(AgentState):
|
||||
sandbox: NotRequired[SandboxState | None]
|
||||
sandbox: Annotated[NotRequired[SandboxState | None], merge_sandbox]
|
||||
thread_data: NotRequired[ThreadDataState | None]
|
||||
title: NotRequired[str | None]
|
||||
artifacts: Annotated[list[str], merge_artifacts]
|
||||
|
||||
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
|
||||
"fact_confidence_threshold": config.fact_confidence_threshold,
|
||||
"injection_enabled": config.injection_enabled,
|
||||
"max_injection_tokens": config.max_injection_tokens,
|
||||
"token_counting": config.token_counting,
|
||||
}
|
||||
|
||||
def get_memory_status(self) -> dict:
|
||||
|
||||
@@ -470,14 +470,32 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
existing_id = self._thread_sandboxes[thread_id]
|
||||
if existing_id in self._sandboxes:
|
||||
suffix = " (post-lock check)" if post_lock else ""
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
info = self._sandbox_infos.get(existing_id)
|
||||
else:
|
||||
del self._thread_sandboxes[thread_id]
|
||||
return None
|
||||
|
||||
del self._thread_sandboxes[thread_id]
|
||||
alive = self._check_tracked_sandbox_alive(existing_id, info) if info is not None else True
|
||||
if alive is False:
|
||||
self._drop_unhealthy_sandbox(
|
||||
existing_id,
|
||||
"in-process cache failed health check",
|
||||
expected_info=info,
|
||||
)
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if self._thread_sandboxes.get(thread_id) != existing_id:
|
||||
return None
|
||||
if existing_id not in self._sandboxes:
|
||||
self._thread_sandboxes.pop(thread_id, None)
|
||||
return None
|
||||
|
||||
suffix = " (post-lock check)" if post_lock else ""
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
|
||||
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
|
||||
"""Promote a warm-pool sandbox back to active tracking if available."""
|
||||
if thread_id is None:
|
||||
@@ -487,7 +505,22 @@ class AioSandboxProvider(SandboxProvider):
|
||||
if sandbox_id not in self._warm_pool:
|
||||
return None
|
||||
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
info, _ = self._warm_pool[sandbox_id]
|
||||
|
||||
alive = self._check_tracked_sandbox_alive(sandbox_id, info)
|
||||
if alive is False:
|
||||
self._drop_unhealthy_sandbox(
|
||||
sandbox_id,
|
||||
"warm-pool cache failed health check",
|
||||
expected_info=info,
|
||||
)
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
warm_item = self._warm_pool.pop(sandbox_id, None)
|
||||
if warm_item is None:
|
||||
return None
|
||||
info, _ = warm_item
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
@@ -527,6 +560,70 @@ class AioSandboxProvider(SandboxProvider):
|
||||
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||
return sandbox_id
|
||||
|
||||
def _check_tracked_sandbox_alive(self, sandbox_id: str, info: SandboxInfo) -> bool | None:
|
||||
"""Return whether a tracked sandbox appears alive, or None if unknown."""
|
||||
try:
|
||||
return self._backend.is_alive(info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check sandbox {sandbox_id} health: {e}")
|
||||
return None
|
||||
|
||||
def _remove_tracked_sandbox(
|
||||
self,
|
||||
sandbox_id: str,
|
||||
*,
|
||||
expected_info: SandboxInfo | None = None,
|
||||
) -> tuple[Sandbox | None, SandboxInfo | None, bool]:
|
||||
"""Remove a sandbox from in-process tracking maps.
|
||||
|
||||
When expected_info is provided, removal only happens if the currently
|
||||
tracked active or warm-pool entry is the exact info object that was
|
||||
checked. This prevents a stale health-check result from deleting a
|
||||
freshly recreated sandbox with the same deterministic id.
|
||||
"""
|
||||
thread_ids_to_remove: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
active_info = self._sandbox_infos.get(sandbox_id)
|
||||
warm_item = self._warm_pool.get(sandbox_id)
|
||||
warm_info = warm_item[0] if warm_item is not None else None
|
||||
if expected_info is not None and active_info is not expected_info and warm_info is not expected_info:
|
||||
return None, None, False
|
||||
|
||||
sandbox = self._sandboxes.pop(sandbox_id, None)
|
||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids_to_remove:
|
||||
del self._thread_sandboxes[tid]
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
if info is None and sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
else:
|
||||
self._warm_pool.pop(sandbox_id, None)
|
||||
|
||||
return sandbox, info, True
|
||||
|
||||
def _drop_unhealthy_sandbox(self, sandbox_id: str, reason: str, *, expected_info: SandboxInfo | None = None) -> None:
|
||||
"""Remove and destroy a sandbox after a definitive failed health check."""
|
||||
sandbox, info, removed = self._remove_tracked_sandbox(sandbox_id, expected_info=expected_info)
|
||||
if not removed:
|
||||
logger.info(f"Skipped dropping sandbox {sandbox_id}: tracked info changed after health check")
|
||||
return
|
||||
|
||||
if sandbox is not None:
|
||||
try:
|
||||
sandbox.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing unhealthy sandbox {sandbox_id}: {e}")
|
||||
|
||||
if info is not None:
|
||||
try:
|
||||
self._backend.destroy(info)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error destroying unhealthy sandbox {sandbox_id}: {e}")
|
||||
|
||||
logger.warning(f"Dropped unhealthy sandbox {sandbox_id}: {reason}")
|
||||
|
||||
def _replica_count(self) -> tuple[int, int]:
|
||||
"""Return configured replicas and currently tracked sandbox count."""
|
||||
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
||||
@@ -617,7 +714,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
async def _acquire_internal_async(self, thread_id: str | None) -> str:
|
||||
"""Async counterpart to ``_acquire_internal``."""
|
||||
cached_id = self._reuse_in_process_sandbox(thread_id)
|
||||
cached_id = await asyncio.to_thread(self._reuse_in_process_sandbox, thread_id)
|
||||
if cached_id is not None:
|
||||
return cached_id
|
||||
|
||||
@@ -625,7 +722,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
sandbox_id = self._sandbox_id_for_thread(thread_id)
|
||||
|
||||
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
|
||||
reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
|
||||
reclaimed_id = await asyncio.to_thread(self._reclaim_warm_pool_sandbox, thread_id, sandbox_id)
|
||||
if reclaimed_id is not None:
|
||||
return reclaimed_id
|
||||
|
||||
@@ -681,7 +778,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
locked = True
|
||||
# Re-check in-process caches under the file lock in case another
|
||||
# thread in this process won the race while we were waiting.
|
||||
cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
|
||||
cached_id = await asyncio.to_thread(self._recheck_cached_sandbox, thread_id, sandbox_id)
|
||||
if cached_id is not None:
|
||||
return cached_id
|
||||
|
||||
@@ -837,22 +934,7 @@ class AioSandboxProvider(SandboxProvider):
|
||||
Args:
|
||||
sandbox_id: The ID of the sandbox to destroy.
|
||||
"""
|
||||
info = None
|
||||
sandbox = None
|
||||
thread_ids_to_remove: list[str] = []
|
||||
|
||||
with self._lock:
|
||||
sandbox = self._sandboxes.pop(sandbox_id, None)
|
||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||
for tid in thread_ids_to_remove:
|
||||
del self._thread_sandboxes[tid]
|
||||
self._last_activity.pop(sandbox_id, None)
|
||||
# Also pull from warm pool if it was parked there
|
||||
if info is None and sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
else:
|
||||
self._warm_pool.pop(sandbox_id, None)
|
||||
sandbox, info, _ = self._remove_tracked_sandbox(sandbox_id)
|
||||
|
||||
if sandbox is not None:
|
||||
# Defense-in-depth: close() already swallows its own errors; this
|
||||
|
||||
@@ -169,6 +169,24 @@ def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str |
|
||||
return "0.0.0.0"
|
||||
|
||||
|
||||
def _is_no_such_container_error(stderr: str, container_name: str) -> bool:
|
||||
"""Return True only when stderr definitively says the container does not exist.
|
||||
|
||||
Docker reports "No such object" / "No such container". Apple Container
|
||||
reports a generic "not found", so that phrase is only trusted when the
|
||||
message also names the inspected container (or refers to a
|
||||
container/object); transient failures whose text happens to contain
|
||||
"not found" (e.g. "command not found", "context not found") must stay on
|
||||
the raise path instead of being misread as a dead container.
|
||||
"""
|
||||
message = stderr.lower()
|
||||
if "no such object" in message or "no such container" in message:
|
||||
return True
|
||||
if "not found" not in message:
|
||||
return False
|
||||
return container_name.lower() in message or "container" in message or "object" in message
|
||||
|
||||
|
||||
class LocalContainerBackend(SandboxBackend):
|
||||
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
||||
|
||||
@@ -335,11 +353,21 @@ class LocalContainerBackend(SandboxBackend):
|
||||
sandbox_id: The deterministic sandbox ID (determines container name).
|
||||
|
||||
Returns:
|
||||
SandboxInfo if container found and healthy, None otherwise.
|
||||
SandboxInfo if container found and healthy, None otherwise. A
|
||||
failed runtime check (e.g. transient daemon error) also returns
|
||||
None — discovery must not adopt a container it cannot verify, and
|
||||
falling through to create keeps acquire recoverable instead of
|
||||
hard-failing on a hiccup.
|
||||
"""
|
||||
container_name = f"{self._container_prefix}-{sandbox_id}"
|
||||
|
||||
if not self._is_container_running(container_name):
|
||||
try:
|
||||
running = self._is_container_running(container_name)
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Could not verify container {container_name} during discovery; not adopting it: {e}")
|
||||
return None
|
||||
|
||||
if not running:
|
||||
return None
|
||||
|
||||
port = self._get_container_port(container_name)
|
||||
@@ -582,6 +610,13 @@ class LocalContainerBackend(SandboxBackend):
|
||||
|
||||
This enables cross-process container discovery — any process can detect
|
||||
containers started by another process via the deterministic container name.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the container runtime cannot answer the inspect
|
||||
query. A failed check is intentionally distinct from a
|
||||
definitive "container does not exist" result so callers do not
|
||||
destroy healthy containers during transient Docker/Container
|
||||
daemon failures.
|
||||
"""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
@@ -590,9 +625,14 @@ class LocalContainerBackend(SandboxBackend):
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip().lower() == "true"
|
||||
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
raise RuntimeError(f"Timed out checking container {container_name}") from exc
|
||||
|
||||
if result.returncode == 0:
|
||||
return result.stdout.strip().lower() == "true"
|
||||
if _is_no_such_container_error(result.stderr, container_name):
|
||||
return False
|
||||
raise RuntimeError(f"Failed to inspect container {container_name}: {result.stderr.strip()}")
|
||||
|
||||
def _get_container_port(self, container_name: str) -> int | None:
|
||||
"""Get the host port of a running container.
|
||||
|
||||
@@ -176,12 +176,16 @@ class RemoteSandboxBackend(SandboxBackend):
|
||||
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
||||
timeout=10,
|
||||
)
|
||||
if resp.ok:
|
||||
data = resp.json()
|
||||
return data.get("status") == "Running"
|
||||
return False
|
||||
except requests.RequestException:
|
||||
except requests.RequestException as exc:
|
||||
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: {exc}") from exc
|
||||
|
||||
if resp.status_code == 404:
|
||||
return False
|
||||
if not resp.ok:
|
||||
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: HTTP {resp.status_code} {resp.text}")
|
||||
|
||||
data = resp.json()
|
||||
return data.get("status") == "Running"
|
||||
|
||||
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import web_search_tool
|
||||
|
||||
__all__ = ["web_search_tool"]
|
||||
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Web Search Tool - Search the web using the Brave Search API.
|
||||
|
||||
Brave Search provides web results from an independent search index via a
|
||||
REST API. An API key is required. Sign up at https://brave.com/search/api/
|
||||
to get one.
|
||||
|
||||
Unlike the DuckDuckGo ``backend: brave`` option (which scrapes results via the
|
||||
DDGS aggregator), this provider calls the official Brave Search API directly,
|
||||
giving structured results, authenticated quota, and a documented SLA.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import httpx
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
|
||||
_DEFAULT_MAX_RESULTS = 5
|
||||
# Brave Search API caps the `count` parameter at 20 results per request.
|
||||
_BRAVE_MAX_COUNT = 20
|
||||
_api_key_warned = False
|
||||
|
||||
|
||||
def _get_api_key() -> str | None:
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None:
|
||||
api_key = (config.model_extra or {}).get("api_key")
|
||||
if isinstance(api_key, str) and api_key.strip():
|
||||
return api_key
|
||||
return os.getenv("BRAVE_SEARCH_API_KEY")
|
||||
|
||||
|
||||
def _coerce_max_results(value: object, *, default: int = _DEFAULT_MAX_RESULTS) -> int:
|
||||
try:
|
||||
coerced = int(value)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning(
|
||||
"Invalid Brave Search max_results=%r; using default %s",
|
||||
value,
|
||||
default,
|
||||
)
|
||||
coerced = default
|
||||
|
||||
return max(1, min(coerced, _BRAVE_MAX_COUNT))
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for information using Brave Search.
|
||||
|
||||
Args:
|
||||
query: Search keywords describing what you want to find. Be specific for better results.
|
||||
max_results: Maximum number of search results to return. Default is 5.
|
||||
"""
|
||||
global _api_key_warned
|
||||
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None and "max_results" in (config.model_extra or {}):
|
||||
max_results = config.model_extra["max_results"]
|
||||
|
||||
count = _coerce_max_results(max_results)
|
||||
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
if not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Brave Search API key is not set. Set BRAVE_SEARCH_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://brave.com/search/api/")
|
||||
return json.dumps(
|
||||
{"error": "BRAVE_SEARCH_API_KEY is not configured", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"X-Subscription-Token": api_key,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
params = {"q": query, "count": count, "text_decorations": False}
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=30) as client:
|
||||
response = client.get(_BRAVE_ENDPOINT, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Brave Search API returned HTTP {e.response.status_code}: {e.response.text}")
|
||||
return json.dumps(
|
||||
{"error": f"Brave Search API error: HTTP {e.response.status_code}", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Brave search failed: {type(e).__name__}: {e}")
|
||||
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||
|
||||
web_results = (data.get("web") or {}).get("results", [])
|
||||
if not web_results:
|
||||
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
|
||||
|
||||
normalized_results = [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("url", ""),
|
||||
"content": r.get("description", ""),
|
||||
}
|
||||
for r in web_results
|
||||
]
|
||||
|
||||
output = {
|
||||
"query": query,
|
||||
"total_results": len(normalized_results),
|
||||
"results": normalized_results,
|
||||
}
|
||||
return json.dumps(output, indent=2, ensure_ascii=False)
|
||||
@@ -0,0 +1,4 @@
|
||||
from .browserless_client import BrowserlessClient
|
||||
from .tools import web_fetch_tool
|
||||
|
||||
__all__ = ["BrowserlessClient", "web_fetch_tool"]
|
||||
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BrowserlessClient:
|
||||
"""Client for Browserless headless Chrome API."""
|
||||
|
||||
def __init__(self, base_url: str, token: str = "", timeout_s: float = 30) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.token = token
|
||||
self.timeout_s = timeout_s
|
||||
|
||||
async def fetch_html(
|
||||
self,
|
||||
url: str,
|
||||
wait_for_event: str = "",
|
||||
wait_for_timeout_ms: int = 0,
|
||||
wait_for_selector: str = "",
|
||||
wait_for_selector_timeout_ms: int = 5000,
|
||||
reject_resource_types: list[str] | None = None,
|
||||
reject_request_pattern: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Fetch the rendered HTML of a page using Browserless.
|
||||
|
||||
Only sends accepted parameters for the current Browserless API version.
|
||||
Sets a default navigation timeout (30s) via query param.
|
||||
|
||||
Args:
|
||||
url: The URL to fetch.
|
||||
wait_for_event: Wait for a page event (e.g. "networkidle", "load").
|
||||
wait_for_timeout_ms: Extra wait after page load.
|
||||
wait_for_selector: CSS selector to wait for.
|
||||
wait_for_selector_timeout_ms: Timeout for selector wait.
|
||||
reject_resource_types: Resource types to block (e.g. ["image"]).
|
||||
reject_request_pattern: URL patterns to block.
|
||||
|
||||
Returns:
|
||||
Rendered HTML content.
|
||||
"""
|
||||
payload: dict[str, Any] = {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
if self.token:
|
||||
payload["token"] = self.token
|
||||
if wait_for_event:
|
||||
payload["waitForEvent"] = wait_for_event
|
||||
if wait_for_timeout_ms > 0:
|
||||
payload["waitForTimeout"] = wait_for_timeout_ms
|
||||
if wait_for_selector:
|
||||
payload["waitForSelector"] = {
|
||||
"selector": wait_for_selector,
|
||||
"timeout": wait_for_selector_timeout_ms,
|
||||
}
|
||||
if reject_resource_types:
|
||||
payload["rejectResourceTypes"] = reject_resource_types
|
||||
if reject_request_pattern:
|
||||
payload["rejectRequestPattern"] = reject_request_pattern
|
||||
|
||||
logger.debug(f"Fetching URL via Browserless: {url}")
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout_s) as client:
|
||||
resp = await client.post(
|
||||
f"{self.base_url}/content",
|
||||
json=payload,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
code = resp.status_code
|
||||
target_code = resp.headers.get("X-Response-Code", "")
|
||||
target_status = resp.headers.get("X-Response-Status", "")
|
||||
|
||||
logger.debug(f"Browserless response: code={code}, target_code={target_code}, target_status={target_status}")
|
||||
|
||||
if code != 200:
|
||||
return f"Error: Browserless HTTP {code}: {resp.text[:200]}"
|
||||
|
||||
html = resp.text
|
||||
if not html or not html.strip():
|
||||
return "Error: Browserless returned empty response"
|
||||
|
||||
return html
|
||||
|
||||
except httpx.TimeoutException:
|
||||
return f"Error: Browserless request timed out after {self.timeout_s}s"
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Browserless request failed: {e}")
|
||||
return f"Error: Browserless request failed: {e!s}"
|
||||
except Exception as e:
|
||||
logger.error(f"Browserless fetch failed: {e}")
|
||||
return f"Error: Browserless fetch failed: {e!s}"
|
||||
@@ -0,0 +1,85 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.utils.readability import ReadabilityExtractor
|
||||
|
||||
from .browserless_client import BrowserlessClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# readability_extractor runs CPU-bound parsing; always call via asyncio.to_thread
|
||||
_readability_extractor = ReadabilityExtractor()
|
||||
|
||||
|
||||
def _get_tool_config(tool_name: str) -> dict | None:
|
||||
"""Get tool config extras safely, returning None if not configured."""
|
||||
config = get_app_config().get_tool_config(tool_name)
|
||||
if config is None:
|
||||
return None
|
||||
extras = config.model_extra
|
||||
return extras if extras is not None else {}
|
||||
|
||||
|
||||
def _get_browserless_client() -> BrowserlessClient:
|
||||
cfg = _get_tool_config("web_fetch")
|
||||
base_url = "http://localhost:3032"
|
||||
token = ""
|
||||
timeout_s = 30.0
|
||||
if cfg is not None:
|
||||
base_url = cfg.get("base_url", base_url)
|
||||
token = cfg.get("token", token)
|
||||
raw = cfg.get("timeout_s", timeout_s)
|
||||
timeout_s = float(raw) if not isinstance(raw, float) else raw
|
||||
return BrowserlessClient(base_url=base_url, token=token, timeout_s=timeout_s)
|
||||
|
||||
|
||||
@tool("web_fetch", parse_docstring=True)
|
||||
async def web_fetch_tool(url: str) -> str:
|
||||
"""Fetch the contents of a web page at a given URL using Browserless (headless Chrome).
|
||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
||||
Do NOT add www. to URLs that do NOT have them.
|
||||
URLs must include the schema: https://example.com is a valid URL while example.com is an invalid URL.
|
||||
|
||||
Args:
|
||||
url: The URL to fetch the contents of.
|
||||
"""
|
||||
try:
|
||||
cfg = _get_tool_config("web_fetch")
|
||||
|
||||
wait_for_event = ""
|
||||
wait_for_timeout_ms = 0
|
||||
wait_for_selector = ""
|
||||
wait_for_selector_timeout_ms = 5000
|
||||
reject_resource_types: list[str] | None = None
|
||||
reject_request_pattern: list[str] | None = None
|
||||
|
||||
if cfg is not None:
|
||||
wait_for_event = cfg.get("wait_for_event", wait_for_event)
|
||||
raw_wait = cfg.get("wait_for_timeout_ms", wait_for_timeout_ms)
|
||||
wait_for_timeout_ms = int(raw_wait) if not isinstance(raw_wait, int) else raw_wait
|
||||
wait_for_selector = cfg.get("wait_for_selector", wait_for_selector)
|
||||
|
||||
client = _get_browserless_client()
|
||||
html = await client.fetch_html(
|
||||
url=url,
|
||||
wait_for_event=wait_for_event,
|
||||
wait_for_timeout_ms=wait_for_timeout_ms,
|
||||
wait_for_selector=wait_for_selector,
|
||||
wait_for_selector_timeout_ms=wait_for_selector_timeout_ms,
|
||||
reject_resource_types=reject_resource_types,
|
||||
reject_request_pattern=reject_request_pattern,
|
||||
)
|
||||
|
||||
if html.startswith("Error:"):
|
||||
return html
|
||||
|
||||
article = await asyncio.to_thread(_readability_extractor.extract_article, html)
|
||||
return article.to_markdown()[:4096]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in web_fetch_tool: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import web_search_tool
|
||||
|
||||
__all__ = ["web_search_tool"]
|
||||
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearxngClient:
|
||||
"""Client for SearXNG meta search engine API."""
|
||||
|
||||
def __init__(self, base_url: str) -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 5,
|
||||
categories: list[str] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search the web using SearXNG.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: Maximum number of results to return.
|
||||
categories: Search categories to use.
|
||||
|
||||
Returns:
|
||||
List of search result dictionaries.
|
||||
"""
|
||||
params: dict[str, Any] = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"language": "auto",
|
||||
"pageno": 1,
|
||||
}
|
||||
if max_results:
|
||||
params["limit"] = max_results
|
||||
if categories:
|
||||
params["categories"] = ",".join(categories)
|
||||
|
||||
logger.debug(f"Searching SearXNG at {self.base_url} with query: {query}")
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.get(
|
||||
f"{self.base_url}/search",
|
||||
params=params,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; DeerFlow/1.0)",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
results = data.get("results", [])
|
||||
return results[:max_results] if max_results else results
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"SearXNG search returned error status: {e}")
|
||||
raise
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"SearXNG search request failed: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"An unexpected error occurred during SearXNG search: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.tools import tool
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
from .searxng_client import SearxngClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_tool_config(tool_name: str) -> dict | None:
|
||||
"""Get tool config extras safely, returning None if not configured."""
|
||||
config = get_app_config().get_tool_config(tool_name)
|
||||
if config is None:
|
||||
return None
|
||||
extras = config.model_extra
|
||||
return extras if extras is not None else {}
|
||||
|
||||
|
||||
def _get_searxng_client() -> SearxngClient:
|
||||
cfg = _get_tool_config("web_search")
|
||||
base_url = "http://localhost:8088"
|
||||
if cfg is not None:
|
||||
base_url = cfg.get("base_url", base_url)
|
||||
return SearxngClient(base_url=base_url)
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
async def web_search_tool(query: str) -> str:
|
||||
"""Search the web using SearXNG.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
"""
|
||||
try:
|
||||
cfg = _get_tool_config("web_search")
|
||||
max_results = 5
|
||||
if cfg is not None:
|
||||
raw = cfg.get("max_results", max_results)
|
||||
max_results = int(raw) if not isinstance(raw, int) else raw
|
||||
|
||||
client = _get_searxng_client()
|
||||
results = await client.search(query, max_results=max_results)
|
||||
|
||||
normalized = [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
"url": r.get("url", ""),
|
||||
"snippet": r.get("content", ""),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
return json.dumps(normalized, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in web_search_tool: {e}")
|
||||
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
||||
paths = get_paths()
|
||||
effective_user = user_id or get_effective_user_id()
|
||||
user_path = paths.user_agent_dir(effective_user, name)
|
||||
if user_path.exists():
|
||||
# Require config.yaml to confirm this is a genuine agent directory,
|
||||
# not a leftover from memory/storage writes (see #3390).
|
||||
if user_path.exists() and (user_path / "config.yaml").exists():
|
||||
return user_path
|
||||
|
||||
legacy_path = paths.agent_dir(name)
|
||||
if legacy_path.exists():
|
||||
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
|
||||
return legacy_path
|
||||
|
||||
return user_path
|
||||
|
||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
||||
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
@@ -116,6 +117,13 @@ class AppConfig(BaseModel):
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
channel_connections: ChannelConnectionsConfig = Field(
|
||||
default_factory=ChannelConnectionsConfig,
|
||||
description=format_field_description(
|
||||
"channel_connections",
|
||||
field_doc="User-facing IM channel connection configuration.",
|
||||
),
|
||||
)
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Configuration for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SlackChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TelegramChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
bot_username: str = ""
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return bool(self.bot_username)
|
||||
|
||||
|
||||
class DiscordChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class BindingCodeChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ChannelConnectionsConfig(BaseModel):
|
||||
"""Top-level config for browser-connectable IM channels."""
|
||||
|
||||
enabled: bool = False
|
||||
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
|
||||
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
|
||||
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
|
||||
feishu: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||
dingtalk: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||
wechat: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||
wecom: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||
|
||||
def provider_status(self, provider: str) -> dict[str, bool]:
|
||||
config = getattr(self, provider, None)
|
||||
if config is None:
|
||||
return {"enabled": False, "configured": False}
|
||||
enabled = bool(config.enabled)
|
||||
return {
|
||||
"enabled": enabled,
|
||||
"configured": enabled and bool(config.configured),
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Configuration for memory mechanism."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
|
||||
le=8000,
|
||||
description="Maximum tokens to use for memory injection",
|
||||
)
|
||||
token_counting: Literal["tiktoken", "char"] = Field(
|
||||
default="tiktoken",
|
||||
description=(
|
||||
"Token counting strategy for memory-injection budgeting. "
|
||||
"'tiktoken' is accurate but the encoding's BPE data may be "
|
||||
"downloaded from a public network endpoint on first use, which "
|
||||
"can block for a long time in network-restricted environments "
|
||||
"(see issue #3402/#3429). 'char' uses a network-free "
|
||||
"CJK-aware character-based estimate and never touches tiktoken."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -14,6 +15,8 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
|
||||
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_local_base_dir() -> Path:
|
||||
"""Return the caller project's writable DeerFlow state directory."""
|
||||
@@ -47,7 +50,13 @@ def make_safe_user_id(raw: str) -> str:
|
||||
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
||||
if sanitized == raw:
|
||||
return raw
|
||||
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||
digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||
return f"{sanitized}-{digest}"
|
||||
|
||||
|
||||
def _legacy_safe_user_id(raw: str, sanitized: str) -> str:
|
||||
"""Bucket name produced by the previous (SHA-1) digest revision for ``raw``."""
|
||||
digest = hashlib.sha1(raw.encode("utf-8"), usedforsecurity=False).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||
return f"{sanitized}-{digest}"
|
||||
|
||||
|
||||
@@ -172,6 +181,32 @@ class Paths:
|
||||
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
||||
return self.base_dir / "users" / _validate_user_id(user_id)
|
||||
|
||||
def prepare_user_dir_for_raw_id(self, raw_user_id: str) -> str:
|
||||
"""Return the safe user ID and migrate this ID's legacy unsafe-id bucket.
|
||||
|
||||
A previous branch revision used SHA-1 for unsafe external user IDs.
|
||||
New IDs use SHA-256; the legacy bucket name is recomputed from the same
|
||||
raw ID, so only this user's own old bucket can ever be moved — a
|
||||
different raw ID sharing the sanitized prefix produces a different
|
||||
legacy digest and is never touched.
|
||||
"""
|
||||
safe_user_id = make_safe_user_id(raw_user_id)
|
||||
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw_user_id)
|
||||
if safe_user_id == raw_user_id:
|
||||
return safe_user_id
|
||||
|
||||
users_dir = self.base_dir / "users"
|
||||
target_dir = users_dir / safe_user_id
|
||||
legacy_dir = users_dir / _legacy_safe_user_id(raw_user_id, sanitized)
|
||||
try:
|
||||
if target_dir.exists() or not legacy_dir.is_dir():
|
||||
return safe_user_id
|
||||
legacy_dir.rename(target_dir)
|
||||
logger.info("Migrated legacy unsafe-id user directory to the current digest format")
|
||||
except OSError:
|
||||
logger.exception("Failed to migrate legacy unsafe-id user directory")
|
||||
return safe_user_id
|
||||
|
||||
def user_memory_file(self, user_id: str) -> Path:
|
||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||
return self.user_dir(user_id) / "memory.json"
|
||||
|
||||
@@ -56,6 +56,9 @@ STARTUP_ONLY_FIELDS: dict[str, str] = {
|
||||
# startup and the live channel clients are not rebuilt on
|
||||
# config.yaml edits.
|
||||
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
|
||||
"channel_connections": (
|
||||
"start_channel_service() wires the connection repository and channel workers once at startup, and the channel-connections router caches the merged provider config on app.state; channel_connections.* edits need a restart."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,20 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
class VolumeMountConfig(BaseModel):
|
||||
"""Configuration for a volume mount."""
|
||||
|
||||
host_path: str = Field(..., description="Path on the host machine")
|
||||
host_path: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Source path for the mount. Resolution depends on the active provider: "
|
||||
"``LocalSandboxProvider`` checks this path from the gateway process — in "
|
||||
"``make dev`` that is the host machine, but in Docker deployments "
|
||||
"(``make up`` / docker-compose) it is the path *inside* the "
|
||||
"``deer-flow-gateway`` container, so the host directory must also be "
|
||||
"bind-mounted into the gateway service for the mount to take effect. "
|
||||
"``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` "
|
||||
"for the sandbox container, where it is resolved by the host Docker daemon "
|
||||
"from the host machine's perspective."
|
||||
),
|
||||
)
|
||||
container_path: str = Field(..., description="Path inside the container")
|
||||
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
"""User-owned IM channel connection persistence."""
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConnectionRow,
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
from deerflow.persistence.channel_connections.sql import (
|
||||
ChannelConnectionRepository,
|
||||
ChannelCredentialCipher,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChannelConnectionRepository",
|
||||
"ChannelConnectionRow",
|
||||
"ChannelConversationRow",
|
||||
"ChannelCredentialCipher",
|
||||
"ChannelCredentialRow",
|
||||
"ChannelOAuthStateRow",
|
||||
]
|
||||
@@ -0,0 +1,111 @@
|
||||
"""ORM models for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class ChannelConnectionRow(Base):
|
||||
__tablename__ = "channel_connections"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected")
|
||||
|
||||
external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
scopes_json: Mapped[list] = mapped_column(JSON, default=list)
|
||||
capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||
last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"owner_user_id",
|
||||
"provider",
|
||||
"external_account_id",
|
||||
"workspace_id",
|
||||
name="uq_channel_connection_owner_provider_identity",
|
||||
),
|
||||
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
|
||||
)
|
||||
|
||||
|
||||
class ChannelCredentialRow(Base):
|
||||
__tablename__ = "channel_credentials"
|
||||
|
||||
connection_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("channel_connections.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
token_type: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||
|
||||
|
||||
class ChannelOAuthStateRow(Base):
|
||||
__tablename__ = "channel_oauth_states"
|
||||
|
||||
state_hash: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||
code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list)
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
|
||||
|
||||
class ChannelConversationRow(Base):
|
||||
__tablename__ = "channel_conversations"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
connection_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("channel_connections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||
external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"connection_id",
|
||||
"external_conversation_id",
|
||||
"external_topic_id",
|
||||
name="uq_channel_conversation_connection_external",
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,387 @@
|
||||
"""SQL repository for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConnectionRow,
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
from deerflow.utils.time import coerce_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChannelCredentialCipher:
|
||||
"""Encrypts provider credentials before they are persisted."""
|
||||
|
||||
def __init__(self, fernet: Fernet) -> None:
|
||||
self._fernet = fernet
|
||||
|
||||
@classmethod
|
||||
def from_key(cls, key: str) -> ChannelCredentialCipher:
|
||||
digest = hashlib.sha256(key.encode("utf-8")).digest()
|
||||
return cls(Fernet(base64.urlsafe_b64encode(digest)))
|
||||
|
||||
def encrypt_text(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")
|
||||
|
||||
def decrypt_text(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
token = value.removeprefix("fernet:v1:")
|
||||
return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")
|
||||
|
||||
|
||||
class ChannelConnectionRepository:
|
||||
"""Persistence facade for channel connections, credentials, and conversations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
*,
|
||||
cipher: ChannelCredentialCipher | None = None,
|
||||
) -> None:
|
||||
self.session_factory = session_factory
|
||||
self._cipher = cipher
|
||||
|
||||
async def close(self) -> None:
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
await close_engine()
|
||||
|
||||
@staticmethod
|
||||
def _new_id() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
@staticmethod
|
||||
def _normalize_optional_identity(value: str | None) -> str:
|
||||
return value or ""
|
||||
|
||||
@staticmethod
|
||||
def _coerce_datetime(value: datetime | None) -> datetime | None:
|
||||
if value is None or value.tzinfo is not None:
|
||||
return value
|
||||
return value.replace(tzinfo=UTC)
|
||||
|
||||
def _encrypt_optional_secret(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self._cipher is None:
|
||||
raise RuntimeError("channel connection encryption key is required")
|
||||
return self._cipher.encrypt_text(value)
|
||||
|
||||
@staticmethod
|
||||
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
|
||||
data = row.to_dict()
|
||||
data["external_account_id"] = data["external_account_id"] or None
|
||||
data["workspace_id"] = data["workspace_id"] or None
|
||||
data["scopes"] = data.pop("scopes_json") or []
|
||||
data["capabilities"] = data.pop("capabilities_json") or {}
|
||||
data["metadata"] = data.pop("metadata_json") or {}
|
||||
for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
|
||||
value = data.get(key)
|
||||
if isinstance(value, datetime):
|
||||
data[key] = coerce_iso(value)
|
||||
return data
|
||||
|
||||
async def upsert_connection(
|
||||
self,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
external_account_id: str | None = None,
|
||||
external_account_name: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
workspace_name: str | None = None,
|
||||
bot_user_id: str | None = None,
|
||||
scopes: list[str] | None = None,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str = "connected",
|
||||
) -> dict[str, Any]:
|
||||
external_account_id_value = self._normalize_optional_identity(external_account_id)
|
||||
workspace_id_value = self._normalize_optional_identity(workspace_id)
|
||||
|
||||
def _apply(row: ChannelConnectionRow) -> None:
|
||||
row.status = status
|
||||
row.external_account_name = external_account_name
|
||||
row.workspace_name = workspace_name
|
||||
row.bot_user_id = bot_user_id
|
||||
row.scopes_json = list(scopes or [])
|
||||
row.capabilities_json = dict(capabilities or {})
|
||||
row.metadata_json = dict(metadata or {})
|
||||
|
||||
stmt = select(ChannelConnectionRow).where(
|
||||
ChannelConnectionRow.owner_user_id == owner_user_id,
|
||||
ChannelConnectionRow.provider == provider,
|
||||
ChannelConnectionRow.external_account_id == external_account_id_value,
|
||||
ChannelConnectionRow.workspace_id == workspace_id_value,
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ChannelConnectionRow(
|
||||
id=self._new_id(),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
external_account_id=external_account_id_value,
|
||||
workspace_id=workspace_id_value,
|
||||
)
|
||||
session.add(row)
|
||||
|
||||
_apply(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
# A concurrent writer inserted the same identity first; retry as
|
||||
# an update of that row.
|
||||
await session.rollback()
|
||||
row = (await session.execute(stmt)).scalar_one()
|
||||
_apply(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._connection_to_dict(row)
|
||||
|
||||
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
|
||||
return [self._connection_to_dict(row) for row in result.scalars()]
|
||||
|
||||
async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelConnectionRow, connection_id)
|
||||
if row is None or row.owner_user_id != owner_user_id:
|
||||
return False
|
||||
|
||||
row.status = "revoked"
|
||||
credential = await session.get(ChannelCredentialRow, connection_id)
|
||||
if credential is not None:
|
||||
await session.delete(credential)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def store_credentials(
|
||||
self,
|
||||
connection_id: str,
|
||||
*,
|
||||
access_token: str | None,
|
||||
refresh_token: str | None = None,
|
||||
token_type: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
refresh_expires_at: datetime | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if self._cipher is None:
|
||||
raise RuntimeError("channel connection encryption key is required")
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelCredentialRow, connection_id)
|
||||
if row is None:
|
||||
row = ChannelCredentialRow(connection_id=connection_id)
|
||||
session.add(row)
|
||||
row.encrypted_access_token = self._cipher.encrypt_text(access_token)
|
||||
row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
|
||||
row.token_type = token_type
|
||||
row.expires_at = expires_at
|
||||
row.refresh_expires_at = refresh_expires_at
|
||||
row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
|
||||
row.version = (row.version or 0) + 1
|
||||
await session.commit()
|
||||
|
||||
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
|
||||
if self._cipher is None:
|
||||
return None
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelCredentialRow, connection_id)
|
||||
if row is None:
|
||||
return None
|
||||
try:
|
||||
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
|
||||
return {
|
||||
"connection_id": row.connection_id,
|
||||
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
|
||||
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
|
||||
"token_type": row.token_type,
|
||||
"expires_at": self._coerce_datetime(row.expires_at),
|
||||
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
|
||||
"extra": json.loads(extra_raw) if extra_raw else {},
|
||||
}
|
||||
except (InvalidToken, UnicodeError, json.JSONDecodeError):
|
||||
logger.warning(
|
||||
"Unable to decrypt channel connection credentials; treating credentials as unavailable",
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def hash_state(state: str) -> str:
|
||||
return hashlib.sha256(state.encode("utf-8")).hexdigest()
|
||||
|
||||
async def create_oauth_state(
|
||||
self,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
state: str,
|
||||
expires_at: datetime,
|
||||
code_verifier: str | None = None,
|
||||
nonce_hash: str | None = None,
|
||||
redirect_after: str | None = None,
|
||||
requested_scopes: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
row = ChannelOAuthStateRow(
|
||||
state_hash=self.hash_state(state),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
|
||||
nonce_hash=nonce_hash,
|
||||
redirect_after=redirect_after,
|
||||
requested_scopes_json=list(requested_scopes or []),
|
||||
metadata_json=dict(metadata or {}),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(func.count())
|
||||
.select_from(ChannelOAuthStateRow)
|
||||
.where(
|
||||
ChannelOAuthStateRow.owner_user_id == owner_user_id,
|
||||
ChannelOAuthStateRow.provider == provider,
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
async def consume_oauth_state(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
state: str,
|
||||
now: datetime | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
current_time = now or datetime.now(UTC)
|
||||
state_hash = self.hash_state(state)
|
||||
async with self.session_factory() as session:
|
||||
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
|
||||
row = await session.get(ChannelOAuthStateRow, state_hash)
|
||||
if row is None or row.provider != provider or row.consumed_at is not None:
|
||||
await session.commit()
|
||||
return None
|
||||
expires_at = self._coerce_datetime(row.expires_at)
|
||||
if expires_at is not None and expires_at < current_time:
|
||||
await session.commit()
|
||||
return None
|
||||
|
||||
# Conditional UPDATE so two concurrent workers cannot both consume
|
||||
# the same binding code: only the writer that flips consumed_at
|
||||
# from NULL wins.
|
||||
result = await session.execute(
|
||||
update(ChannelOAuthStateRow)
|
||||
.where(
|
||||
ChannelOAuthStateRow.state_hash == state_hash,
|
||||
ChannelOAuthStateRow.consumed_at.is_(None),
|
||||
)
|
||||
.values(consumed_at=current_time)
|
||||
)
|
||||
await session.commit()
|
||||
if result.rowcount != 1:
|
||||
return None
|
||||
return {
|
||||
"owner_user_id": row.owner_user_id,
|
||||
"provider": row.provider,
|
||||
"requested_scopes": row.requested_scopes_json or [],
|
||||
"metadata": row.metadata_json or {},
|
||||
"redirect_after": row.redirect_after,
|
||||
}
|
||||
|
||||
async def find_connection_by_external_identity(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
external_account_id: str,
|
||||
workspace_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(ChannelConnectionRow)
|
||||
.where(
|
||||
ChannelConnectionRow.provider == provider,
|
||||
ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
|
||||
ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
|
||||
ChannelConnectionRow.status == "connected",
|
||||
)
|
||||
.order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._connection_to_dict(row) if row is not None else None
|
||||
|
||||
async def set_thread_id(
|
||||
self,
|
||||
*,
|
||||
connection_id: str,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
external_conversation_id: str,
|
||||
thread_id: str,
|
||||
external_topic_id: str | None = None,
|
||||
) -> None:
|
||||
topic_id = external_topic_id or ""
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(ChannelConversationRow).where(
|
||||
ChannelConversationRow.connection_id == connection_id,
|
||||
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
||||
ChannelConversationRow.external_topic_id == topic_id,
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ChannelConversationRow(
|
||||
id=self._new_id(),
|
||||
connection_id=connection_id,
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
external_conversation_id=external_conversation_id,
|
||||
external_topic_id=topic_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
session.add(row)
|
||||
else:
|
||||
row.thread_id = thread_id
|
||||
row.owner_user_id = owner_user_id
|
||||
row.provider = provider
|
||||
await session.commit()
|
||||
|
||||
async def get_thread_id(
|
||||
self,
|
||||
connection_id: str,
|
||||
external_conversation_id: str,
|
||||
external_topic_id: str | None = None,
|
||||
) -> str | None:
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(ChannelConversationRow.thread_id).where(
|
||||
ChannelConversationRow.connection_id == connection_id,
|
||||
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
||||
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
|
||||
)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
@@ -14,10 +14,26 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and
|
||||
there is no matching entity directory.
|
||||
"""
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConnectionRow,
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
|
||||
__all__ = [
|
||||
"ChannelConnectionRow",
|
||||
"ChannelConversationRow",
|
||||
"ChannelCredentialRow",
|
||||
"ChannelOAuthStateRow",
|
||||
"FeedbackRow",
|
||||
"RunEventRow",
|
||||
"RunRow",
|
||||
"ThreadMetaRow",
|
||||
"UserRow",
|
||||
]
|
||||
|
||||
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
"""Move a thread metadata row to a new owner.
|
||||
|
||||
Intended for trusted internal repair/migration paths. No-op if the
|
||||
row does not exist or the caller fails the owner check.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||
|
||||
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
|
||||
if record is None:
|
||||
return
|
||||
record["user_id"] = owner_user_id
|
||||
record["updated_at"] = now_iso()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||
if record is None:
|
||||
|
||||
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
||||
row.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
async def update_owner(
|
||||
self,
|
||||
thread_id: str,
|
||||
owner_user_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Move a thread metadata row to ``owner_user_id``."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
||||
@@ -7,7 +7,7 @@ directly from ``deerflow.runtime``.
|
||||
|
||||
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
|
||||
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
|
||||
from .serialization import serialize, serialize_channel_values, serialize_channel_values_for_api, serialize_lc_object, serialize_messages_tuple, strip_data_url_image_blocks
|
||||
from .store import get_store, make_store, reset_store, store_context
|
||||
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
|
||||
|
||||
@@ -29,8 +29,10 @@ __all__ = [
|
||||
# serialization
|
||||
"serialize",
|
||||
"serialize_channel_values",
|
||||
"serialize_channel_values_for_api",
|
||||
"serialize_lc_object",
|
||||
"serialize_messages_tuple",
|
||||
"strip_data_url_image_blocks",
|
||||
# store
|
||||
"get_store",
|
||||
"make_store",
|
||||
|
||||
@@ -6,6 +6,7 @@ since all mutations happen within the same event loop).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
@@ -13,7 +14,11 @@ from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
class MemoryRunEventStore(RunEventStore):
|
||||
def __init__(self) -> None:
|
||||
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
|
||||
self._events: dict[str, list[dict]] = {} # thread_id -> seq-sorted event list
|
||||
# Messages-only projection of ``_events`` (same dict objects, no copies),
|
||||
# kept in seq order so message pagination is O(log m + page) via bisect
|
||||
# instead of re-scanning every event on each request.
|
||||
self._messages: dict[str, list[dict]] = {} # thread_id -> seq-sorted message list
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
||||
|
||||
def _next_seq(self, thread_id: str) -> int:
|
||||
@@ -45,6 +50,8 @@ class MemoryRunEventStore(RunEventStore):
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._events.setdefault(thread_id, []).append(record)
|
||||
if category == "message":
|
||||
self._messages.setdefault(thread_id, []).append(record)
|
||||
return record
|
||||
|
||||
async def put(
|
||||
@@ -76,18 +83,20 @@ class MemoryRunEventStore(RunEventStore):
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
messages = [e for e in all_events if e["category"] == "message"]
|
||||
# ``messages`` is messages-only and seq-sorted, so the seq window is a
|
||||
# contiguous slice located with bisect (O(log m)) rather than a full scan.
|
||||
messages = self._messages.get(thread_id, [])
|
||||
|
||||
if before_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] < before_seq]
|
||||
# Take the last `limit` records
|
||||
return messages[-limit:]
|
||||
# Records with seq < before_seq, then the last `limit` of them.
|
||||
hi = bisect.bisect_left(messages, before_seq, key=lambda e: e["seq"])
|
||||
return messages[max(0, hi - limit) : hi]
|
||||
elif after_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] > after_seq]
|
||||
return messages[:limit]
|
||||
# Records with seq > after_seq, then the first `limit` of them.
|
||||
lo = bisect.bisect_right(messages, after_seq, key=lambda e: e["seq"])
|
||||
return messages[lo : lo + limit]
|
||||
else:
|
||||
# Return the latest `limit` records, ascending
|
||||
# Return the latest `limit` records, ascending.
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
@@ -110,11 +119,11 @@ class MemoryRunEventStore(RunEventStore):
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
return sum(1 for e in all_events if e["category"] == "message")
|
||||
return len(self._messages.get(thread_id, []))
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
events = self._events.pop(thread_id, [])
|
||||
self._messages.pop(thread_id, None)
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return len(events)
|
||||
|
||||
@@ -125,4 +134,6 @@ class MemoryRunEventStore(RunEventStore):
|
||||
remaining = [e for e in all_events if e["run_id"] != run_id]
|
||||
removed = len(all_events) - len(remaining)
|
||||
self._events[thread_id] = remaining
|
||||
# Keep the message projection in lockstep (same surviving dict objects).
|
||||
self._messages[thread_id] = [e for e in remaining if e["category"] == "message"]
|
||||
return removed
|
||||
|
||||
@@ -164,7 +164,18 @@ class RunJournal(BaseCallbackHandler):
|
||||
metadata={"caller": caller, **(metadata or {})},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Any,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Nested chain ends fire for internal graph nodes; only the root chain
|
||||
# represents the user-visible run lifecycle.
|
||||
if parent_run_id is not None:
|
||||
return
|
||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
|
||||
|
||||
@@ -83,6 +83,7 @@ class RunRecord:
|
||||
multitask_strategy: str = "reject"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
user_id: str | None = None
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
task: asyncio.Task | None = field(default=None, repr=False)
|
||||
@@ -118,13 +119,48 @@ class RunManager:
|
||||
persistence_retry_policy: PersistenceRetryPolicy | None = None,
|
||||
) -> None:
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
# Secondary index: thread_id -> insertion-ordered run_id set (a dict is
|
||||
# used as an ordered set), maintained in lockstep with ``_runs`` so
|
||||
# per-thread queries avoid O(total in-memory runs) full scans while
|
||||
# preserving ``_runs`` iteration order (see ``_thread_records_locked``).
|
||||
self._runs_by_thread: dict[str, dict[str, None]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
|
||||
|
||||
def _index_run_locked(self, record: RunRecord) -> None:
|
||||
"""Register *record* in the thread index. Caller must hold ``self._lock``."""
|
||||
self._runs_by_thread.setdefault(record.thread_id, {})[record.run_id] = None
|
||||
|
||||
def _unindex_run_locked(self, run_id: str, thread_id: str) -> None:
|
||||
"""Drop *run_id* from the thread index. Caller must hold ``self._lock``."""
|
||||
bucket = self._runs_by_thread.get(thread_id)
|
||||
if bucket is not None:
|
||||
bucket.pop(run_id, None)
|
||||
if not bucket:
|
||||
self._runs_by_thread.pop(thread_id, None)
|
||||
|
||||
def _thread_records_locked(self, thread_id: str) -> list[RunRecord]:
|
||||
"""Return live in-memory records for *thread_id*. Caller must hold ``self._lock``.
|
||||
|
||||
Uses the ``_runs_by_thread`` index for O(runs-in-thread) lookup instead of
|
||||
scanning every in-memory run. Correctness rests on the index and ``_runs``
|
||||
being mutated in lockstep under ``self._lock`` (no ``await`` between the two
|
||||
writes), so any holder of the lock sees them agree. The ``self._runs.get``
|
||||
filter is defense-in-depth, not reconciliation: it drops a stale id still in
|
||||
the index but already gone from ``_runs``, yet it cannot recover a run that is
|
||||
in ``_runs`` but missing from the index (such a run would be silently
|
||||
omitted). It guards only that one direction, should a future refactor ever
|
||||
break the lockstep invariant.
|
||||
"""
|
||||
run_ids = self._runs_by_thread.get(thread_id)
|
||||
if not run_ids:
|
||||
return []
|
||||
return [record for run_id in run_ids if (record := self._runs.get(run_id)) is not None]
|
||||
|
||||
@staticmethod
|
||||
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
payload = {
|
||||
"thread_id": record.thread_id,
|
||||
"assistant_id": record.assistant_id,
|
||||
"status": record.status.value,
|
||||
@@ -135,6 +171,9 @@ class RunManager:
|
||||
"created_at": record.created_at,
|
||||
"model_name": record.model_name,
|
||||
}
|
||||
if record.user_id is not None:
|
||||
payload["user_id"] = record.user_id
|
||||
return payload
|
||||
|
||||
async def _call_store_with_retry(
|
||||
self,
|
||||
@@ -241,6 +280,7 @@ class RunManager:
|
||||
kwargs=row.get("kwargs") or {},
|
||||
created_at=row.get("created_at") or "",
|
||||
updated_at=row.get("updated_at") or "",
|
||||
user_id=row.get("user_id"),
|
||||
error=row.get("error"),
|
||||
model_name=row.get("model_name"),
|
||||
store_only=True,
|
||||
@@ -320,6 +360,7 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
user_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
@@ -333,11 +374,13 @@ class RunManager:
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
self._index_run_locked(record)
|
||||
persisted = False
|
||||
try:
|
||||
await self._persist_new_run_to_store(record)
|
||||
@@ -349,6 +392,7 @@ class RunManager:
|
||||
# Also covers cancellation, which bypasses ``except Exception``.
|
||||
if not persisted:
|
||||
self._runs.pop(run_id, None)
|
||||
self._unindex_run_locked(run_id, record.thread_id)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
@@ -404,8 +448,7 @@ class RunManager:
|
||||
limit: Maximum number of runs to return.
|
||||
"""
|
||||
async with self._lock:
|
||||
# Dict insertion order gives deterministic results when timestamps tie.
|
||||
memory_records = [r for r in self._runs.values() if r.thread_id == thread_id]
|
||||
memory_records = self._thread_records_locked(thread_id)
|
||||
if self._store is None:
|
||||
return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
|
||||
records_by_id = {record.run_id: record for record in memory_records}
|
||||
@@ -504,6 +547,7 @@ class RunManager:
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
model_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@@ -524,7 +568,7 @@ class RunManager:
|
||||
if multitask_strategy not in _supported_strategies:
|
||||
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
||||
|
||||
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
|
||||
inflight = [r for r in self._thread_records_locked(thread_id) if r.status in (RunStatus.pending, RunStatus.running)]
|
||||
|
||||
if multitask_strategy == "reject" and inflight:
|
||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||
@@ -546,11 +590,13 @@ class RunManager:
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
user_id=user_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
model_name=model_name,
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
self._index_run_locked(record)
|
||||
persisted = False
|
||||
try:
|
||||
await self._persist_new_run_to_store(record)
|
||||
@@ -562,6 +608,7 @@ class RunManager:
|
||||
# Also covers cancellation, which bypasses ``except Exception``.
|
||||
if not persisted:
|
||||
self._runs.pop(run_id, None)
|
||||
self._unindex_run_locked(run_id, record.thread_id)
|
||||
|
||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||
for r in inflight:
|
||||
@@ -635,14 +682,16 @@ class RunManager:
|
||||
async def has_inflight(self, thread_id: str) -> bool:
|
||||
"""Return ``True`` if *thread_id* has a pending or running run."""
|
||||
async with self._lock:
|
||||
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
|
||||
return any(r.status in (RunStatus.pending, RunStatus.running) for r in self._thread_records_locked(thread_id))
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
||||
"""Remove a run record after an optional delay."""
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
async with self._lock:
|
||||
self._runs.pop(run_id, None)
|
||||
record = self._runs.pop(run_id, None)
|
||||
if record is not None:
|
||||
self._unindex_run_locked(run_id, record.thread_id)
|
||||
logger.debug("Run record %s cleaned up", run_id)
|
||||
|
||||
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||
|
||||
@@ -56,6 +56,56 @@ def serialize_channel_values(channel_values: dict[str, Any]) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def strip_data_url_image_blocks(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Remove ``data:``-scheme ``image_url`` blocks from *hide_from_ui* messages.
|
||||
|
||||
The history and run-wait endpoints return checkpoint-persisted messages to
|
||||
the frontend. ``ViewImageMiddleware`` stores full base64 image payloads in
|
||||
``hide_from_ui`` human messages — these are internal model context and must
|
||||
not be sent over the wire (huge response bodies, no UI value).
|
||||
|
||||
Only content blocks of type ``image_url`` whose URL starts with ``data:``
|
||||
are stripped. Text blocks, ``https://`` image URLs, and non-hidden
|
||||
messages are left untouched so that message ordering and count are
|
||||
preserved.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
# Only touch messages explicitly flagged as hidden from the UI.
|
||||
additional_kwargs = msg.get("additional_kwargs")
|
||||
if not (isinstance(additional_kwargs, dict) and additional_kwargs.get("hide_from_ui") is True):
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
# Filter out image_url blocks with data: scheme.
|
||||
filtered = [block for block in content if not (isinstance(block, dict) and block.get("type") == "image_url" and isinstance(block.get("image_url"), dict) and str(block["image_url"].get("url", "")).startswith("data:"))]
|
||||
result.append({**msg, "content": filtered})
|
||||
return result
|
||||
|
||||
|
||||
def serialize_channel_values_for_api(channel_values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Serialize channel values and strip base64 image data from messages.
|
||||
|
||||
Convenience wrapper combining :func:`serialize_channel_values` with
|
||||
:func:`strip_data_url_image_blocks`. Use this in all REST endpoints
|
||||
that return channel values to the frontend so that ``data:``-scheme
|
||||
base64 image payloads are never sent over the wire.
|
||||
"""
|
||||
result = serialize_channel_values(channel_values)
|
||||
if isinstance(result.get("messages"), list):
|
||||
result["messages"] = strip_data_url_image_blocks(result["messages"])
|
||||
return result
|
||||
|
||||
|
||||
def serialize_messages_tuple(obj: Any) -> Any:
|
||||
"""Serialize a messages-mode tuple ``(chunk, metadata)``."""
|
||||
if isinstance(obj, tuple) and len(obj) == 2:
|
||||
|
||||
@@ -147,7 +147,17 @@ class LocalSandboxProvider(SandboxProvider):
|
||||
mount.container_path,
|
||||
)
|
||||
continue
|
||||
# Ensure the host path exists before adding mapping
|
||||
# Ensure the host path exists before adding mapping.
|
||||
#
|
||||
# ``host_path`` is resolved against the filesystem of the
|
||||
# process running this provider — for ``make dev`` that is
|
||||
# the host machine, but for ``make up`` it is the
|
||||
# ``deer-flow-gateway`` container, so any host path that
|
||||
# isn't bind-mounted into the gateway image will be missing
|
||||
# here. Skipping silently makes this a high-cost-to-debug
|
||||
# silent failure (sandbox skill / tool reads an empty dir
|
||||
# instead of the configured mount), so escalate to ERROR
|
||||
# and include actionable guidance. See #3244.
|
||||
if host_path.exists():
|
||||
mappings.append(
|
||||
PathMapping(
|
||||
@@ -157,10 +167,16 @@ class LocalSandboxProvider(SandboxProvider):
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Mount host_path does not exist, skipping: %s -> %s",
|
||||
logger.error(
|
||||
"sandbox.mounts entry %s -> %s ignored: host_path %s does not exist from the "
|
||||
"perspective of the gateway process. In Docker deployments (make up / docker-compose), "
|
||||
"this path must also be bind-mounted into the gateway container — add a matching "
|
||||
"volume entry under services.gateway.volumes in docker/docker-compose.yaml (and use "
|
||||
"the in-container path here), or run in local mode (make dev) where the gateway sees "
|
||||
"the host filesystem directly.",
|
||||
mount.host_path,
|
||||
mount.container_path,
|
||||
mount.host_path,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail if config loading fails
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import replace as dc_replace
|
||||
from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Command
|
||||
|
||||
from deerflow.agents.thread_state import SandboxState, ThreadDataState
|
||||
from deerflow.sandbox import get_sandbox_provider
|
||||
@@ -126,3 +131,87 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
|
||||
# No sandbox to release
|
||||
return await super().aafter_agent(state, runtime)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call wrappers: persist lazily-acquired sandbox state into the
|
||||
# graph state via Command(update=...).
|
||||
#
|
||||
# Background:
|
||||
# ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates
|
||||
# ``runtime.state["sandbox"]`` directly. That mutation is local to the
|
||||
# current tool invocation and is NOT picked up by LangGraph's channel
|
||||
# reducer, so subsequent graph steps (and downstream consumers such as
|
||||
# ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``)
|
||||
# cannot observe the sandbox id. Wrapping the tool call lets us detect
|
||||
# a fresh lazy init by diffing the state snapshot before/after the
|
||||
# handler and emit a proper state update via ``Command``.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _read_sandbox_id_from_state(state: object) -> str | None:
|
||||
if not isinstance(state, dict):
|
||||
return None
|
||||
sandbox_state = state.get("sandbox")
|
||||
if not isinstance(sandbox_state, dict):
|
||||
return None
|
||||
sandbox_id = sandbox_state.get("sandbox_id")
|
||||
return sandbox_id if isinstance(sandbox_id, str) else None
|
||||
|
||||
@staticmethod
|
||||
def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command:
|
||||
"""Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted.
|
||||
|
||||
- ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})``
|
||||
- ``Command`` with dict update -> merge ``sandbox`` key, preserve all
|
||||
existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...).
|
||||
- ``Command`` with non-dict / None update -> leave it untouched to
|
||||
avoid silent data loss on unknown update shapes.
|
||||
"""
|
||||
sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}}
|
||||
|
||||
if isinstance(result, ToolMessage):
|
||||
return Command(update={**sandbox_update, "messages": [result]})
|
||||
|
||||
existing_update = result.update
|
||||
if isinstance(existing_update, dict):
|
||||
merged_update = {**existing_update, **sandbox_update}
|
||||
return dc_replace(result, update=merged_update)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None:
|
||||
"""Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes)."""
|
||||
runtime = request.runtime
|
||||
if runtime is None or runtime.state is None:
|
||||
return None
|
||||
return SandboxMiddleware._read_sandbox_id_from_state(runtime.state)
|
||||
|
||||
@override
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
prev_sandbox_id = self._read_sandbox_id_from_request(request)
|
||||
result = handler(request)
|
||||
if prev_sandbox_id is not None:
|
||||
return result
|
||||
curr_sandbox_id = self._read_sandbox_id_from_request(request)
|
||||
if curr_sandbox_id is None:
|
||||
return result
|
||||
return self._attach_sandbox_update(result, curr_sandbox_id)
|
||||
|
||||
@override
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
prev_sandbox_id = self._read_sandbox_id_from_request(request)
|
||||
result = await handler(request)
|
||||
if prev_sandbox_id is not None:
|
||||
return result
|
||||
curr_sandbox_id = self._read_sandbox_id_from_request(request)
|
||||
if curr_sandbox_id is None:
|
||||
return result
|
||||
return self._attach_sandbox_update(result, curr_sandbox_id)
|
||||
|
||||
@@ -153,7 +153,7 @@ async def _scan_skill_file_or_raise(skill_dir: Path, path: Path, skill_name: str
|
||||
rel_path = path.relative_to(skill_dir).as_posix()
|
||||
location = f"{skill_name}/{rel_path}"
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
content = await asyncio.to_thread(path.read_text, encoding="utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
raise SkillSecurityScanError(f"Security scan failed for skill '{skill_name}': {location} must be valid UTF-8") from e
|
||||
|
||||
@@ -174,15 +174,17 @@ async def _scan_skill_file_or_raise(skill_dir: Path, path: Path, skill_name: str
|
||||
raise SkillSecurityScanError(f"Security scan failed for {location}: invalid scanner decision {decision!r}")
|
||||
|
||||
|
||||
def _collect_scannable_files(skill_dir: Path) -> list[Path]:
|
||||
"""Enumerate archive files for scanning (blocking; run off the event loop)."""
|
||||
return [candidate for candidate in sorted(skill_dir.rglob("*")) if candidate.is_file()]
|
||||
|
||||
|
||||
async def _scan_skill_archive_contents_or_raise(skill_dir: Path, skill_name: str) -> None:
|
||||
"""Run the skill security scanner against all installable text and script files."""
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
await _scan_skill_file_or_raise(skill_dir, skill_md, skill_name, executable=False)
|
||||
|
||||
for path in sorted(skill_dir.rglob("*")):
|
||||
if not path.is_file():
|
||||
continue
|
||||
|
||||
for path in await asyncio.to_thread(_collect_scannable_files, skill_dir):
|
||||
rel_path = path.relative_to(skill_dir)
|
||||
if rel_path == Path("SKILL.md"):
|
||||
continue
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import json
|
||||
import logging
|
||||
@@ -21,6 +22,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
|
||||
|
||||
# Bound for the best-effort temp-dir cleanup so a stalled filesystem (e.g. NFS)
|
||||
# cannot hold back the install outcome propagating out of the finally block.
|
||||
_INSTALL_TMP_CLEANUP_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
|
||||
class LocalSkillStorage(SkillStorage):
|
||||
"""Skill storage backed by the local filesystem.
|
||||
@@ -94,19 +99,56 @@ class LocalSkillStorage(SkillStorage):
|
||||
make_skill_written_path_sandbox_readable(self.get_custom_skill_dir(name), target)
|
||||
|
||||
async def ainstall_skill_from_archive(self, archive_path: str | Path) -> dict:
|
||||
from deerflow.skills.installer import _scan_skill_archive_contents_or_raise
|
||||
|
||||
logger.info("Installing skill from %s", archive_path)
|
||||
path = Path(archive_path)
|
||||
custom_dir = self._host_root / "custom"
|
||||
|
||||
# The per-file security scan is an async LLM call and must stay on the
|
||||
# event loop; every filesystem phase around it runs in a worker thread.
|
||||
tmp = await asyncio.to_thread(tempfile.mkdtemp)
|
||||
try:
|
||||
skill_dir, skill_name, target = await asyncio.to_thread(self._prepare_skill_archive, path, Path(tmp), custom_dir, archive_path)
|
||||
|
||||
await _scan_skill_archive_contents_or_raise(skill_dir, skill_name)
|
||||
|
||||
await asyncio.to_thread(self._commit_skill_install, skill_dir, skill_name, custom_dir, target)
|
||||
logger.info("Skill %r installed to %s", skill_name, target)
|
||||
finally:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(self._cleanup_install_tmp, tmp),
|
||||
timeout=_INSTALL_TMP_CLEANUP_TIMEOUT_SECONDS,
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning("Timed out cleaning up skill install temp dir %s", tmp)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"skill_name": skill_name,
|
||||
"message": f"Skill '{skill_name}' installed successfully",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_install_tmp(tmp: str) -> None:
|
||||
"""Best-effort removal that never masks the install outcome, but leaves a trace."""
|
||||
try:
|
||||
shutil.rmtree(tmp)
|
||||
except OSError:
|
||||
logger.warning("Failed to clean up skill install temp dir %s", tmp, exc_info=True)
|
||||
|
||||
def _prepare_skill_archive(self, path: Path, tmp_path: Path, custom_dir: Path, archive_path: str | Path) -> tuple[Path, str, Path]:
|
||||
"""Extract and validate the archive (blocking; runs off the event loop)."""
|
||||
import zipfile
|
||||
|
||||
from deerflow.skills.installer import (
|
||||
SkillAlreadyExistsError,
|
||||
_move_staged_skill_into_reserved_target,
|
||||
_scan_skill_archive_contents_or_raise,
|
||||
resolve_skill_dir_from_archive,
|
||||
safe_extract_skill_archive,
|
||||
)
|
||||
from deerflow.skills.validation import _validate_skill_frontmatter
|
||||
|
||||
logger.info("Installing skill from %s", archive_path)
|
||||
path = Path(archive_path)
|
||||
if not path.is_file():
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Skill file not found: {archive_path}")
|
||||
@@ -114,47 +156,40 @@ class LocalSkillStorage(SkillStorage):
|
||||
if path.suffix != ".skill":
|
||||
raise ValueError("File must have .skill extension")
|
||||
|
||||
custom_dir = self._host_root / "custom"
|
||||
custom_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_path = Path(tmp)
|
||||
try:
|
||||
zf = zipfile.ZipFile(path, "r")
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Skill file not found: {archive_path}") from None
|
||||
except (zipfile.BadZipFile, IsADirectoryError):
|
||||
raise ValueError("File is not a valid ZIP archive") from None
|
||||
|
||||
try:
|
||||
zf = zipfile.ZipFile(path, "r")
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"Skill file not found: {archive_path}") from None
|
||||
except (zipfile.BadZipFile, IsADirectoryError):
|
||||
raise ValueError("File is not a valid ZIP archive") from None
|
||||
with zf:
|
||||
safe_extract_skill_archive(zf, tmp_path)
|
||||
|
||||
with zf:
|
||||
safe_extract_skill_archive(zf, tmp_path)
|
||||
skill_dir = resolve_skill_dir_from_archive(tmp_path)
|
||||
|
||||
skill_dir = resolve_skill_dir_from_archive(tmp_path)
|
||||
is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir)
|
||||
if not is_valid:
|
||||
raise ValueError(f"Invalid skill: {message}")
|
||||
if not skill_name or "/" in skill_name or "\\" in skill_name or ".." in skill_name:
|
||||
raise ValueError(f"Invalid skill name: {skill_name}")
|
||||
|
||||
is_valid, message, skill_name = _validate_skill_frontmatter(skill_dir)
|
||||
if not is_valid:
|
||||
raise ValueError(f"Invalid skill: {message}")
|
||||
if not skill_name or "/" in skill_name or "\\" in skill_name or ".." in skill_name:
|
||||
raise ValueError(f"Invalid skill name: {skill_name}")
|
||||
target = custom_dir / skill_name
|
||||
if target.exists():
|
||||
raise SkillAlreadyExistsError(f"Skill '{skill_name}' already exists")
|
||||
|
||||
target = custom_dir / skill_name
|
||||
if target.exists():
|
||||
raise SkillAlreadyExistsError(f"Skill '{skill_name}' already exists")
|
||||
return skill_dir, skill_name, target
|
||||
|
||||
await _scan_skill_archive_contents_or_raise(skill_dir, skill_name)
|
||||
def _commit_skill_install(self, skill_dir: Path, skill_name: str, custom_dir: Path, target: Path) -> None:
|
||||
"""Stage and move the validated skill into place (blocking; runs off the event loop)."""
|
||||
from deerflow.skills.installer import _move_staged_skill_into_reserved_target
|
||||
|
||||
with tempfile.TemporaryDirectory(prefix=f".installing-{skill_name}-", dir=custom_dir) as staging_root:
|
||||
staging_target = Path(staging_root) / skill_name
|
||||
shutil.copytree(skill_dir, staging_target)
|
||||
_move_staged_skill_into_reserved_target(staging_target, target)
|
||||
logger.info("Skill %r installed to %s", skill_name, target)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"skill_name": skill_name,
|
||||
"message": f"Skill '{skill_name}' installed successfully",
|
||||
}
|
||||
with tempfile.TemporaryDirectory(prefix=f".installing-{skill_name}-", dir=custom_dir) as staging_root:
|
||||
staging_target = Path(staging_root) / skill_name
|
||||
shutil.copytree(skill_dir, staging_target)
|
||||
_move_staged_skill_into_reserved_target(staging_target, target)
|
||||
|
||||
def delete_custom_skill(self, name: str, *, history_meta: dict | None = None) -> None:
|
||||
self.validate_skill_name(name)
|
||||
|
||||
@@ -351,6 +351,7 @@ class SubagentExecutor:
|
||||
middleware=middlewares,
|
||||
system_prompt=None,
|
||||
state_schema=ThreadState,
|
||||
checkpointer=False,
|
||||
)
|
||||
|
||||
async def _load_skills(self) -> list[Skill]:
|
||||
|
||||
@@ -28,6 +28,25 @@ def setup_agent(
|
||||
skills: Optional list of skill names this agent should use. None means use all enabled skills, empty list means no skills.
|
||||
"""
|
||||
|
||||
# Reject empty / whitespace-only soul before touching the filesystem.
|
||||
# Without this guard the tool would happily persist an empty SOUL.md and
|
||||
# still report success, which caused the frontend to enter the "agent
|
||||
# created" state for an unusable agent (issue #3549). Failing loud lets
|
||||
# the model retry instead of silently producing a broken artifact and,
|
||||
# together with the upstream agent_name fix, prevents the global default
|
||||
# SOUL.md from being overwritten with empty content.
|
||||
if not soul or not soul.strip():
|
||||
return Command(
|
||||
update={
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content="Error: soul content is empty; refusing to create agent with an empty SOUL.md",
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None
|
||||
agent_dir = None
|
||||
is_new_dir = False
|
||||
|
||||
@@ -36,6 +36,7 @@ dependencies = [
|
||||
"sqlalchemy[asyncio]>=2.0,<3.0",
|
||||
"aiosqlite>=0.19",
|
||||
"alembic>=1.13",
|
||||
"cryptography>=43.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Regression anchors: channel runtime-config handlers must not block the event loop.
|
||||
|
||||
``configure_channel_provider_runtime`` and ``disconnect_channel_provider_runtime``
|
||||
persist UI-entered channel credentials through ``ChannelRuntimeConfigStore``,
|
||||
whose construction reads its JSON file and whose setters rewrite it
|
||||
(``json.dump`` + ``Path.replace`` + ``chmod``). The handlers offload both via
|
||||
``asyncio.to_thread``; if that regresses back onto the event loop, the strict
|
||||
Blockbuster gate raises ``BlockingError`` and these tests fail.
|
||||
|
||||
The handlers are invoked directly with a minimal Starlette ``Request`` so the
|
||||
surface under test is exactly the router's own IO, mirroring
|
||||
``test_agents_router``. Test-side seeding/inspection is offloaded with
|
||||
``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.gateway.routers.channel_connections import (
|
||||
ChannelRuntimeConfigRequest,
|
||||
configure_channel_provider_runtime,
|
||||
disconnect_channel_provider_runtime,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
# Pre-import: the handlers import this module lazily; the import's file IO
|
||||
# must happen at collection time, not on the event loop under the gate.
|
||||
importlib.import_module("app.channels.service")
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_app_config():
|
||||
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _make_request(tmp_path) -> Request:
|
||||
app = FastAPI()
|
||||
app.state.channel_connections_config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
app.state.channels_config = {}
|
||||
app.state.channel_connection_repo = _FakeRepo()
|
||||
store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app.state.channel_runtime_config_store = store
|
||||
user = SimpleNamespace(id=UUID("11111111-2222-3333-4444-555555555555"), system_role="admin")
|
||||
return Request({"type": "http", "app": app, "headers": [], "state": {"user": user}})
|
||||
|
||||
|
||||
class _FakeRepo:
|
||||
async def list_connections(self, owner_user_id):
|
||||
return []
|
||||
|
||||
|
||||
async def test_configure_runtime_channel_does_not_block_event_loop(tmp_path) -> None:
|
||||
request = await asyncio.to_thread(_make_request, tmp_path)
|
||||
|
||||
response = await configure_channel_provider_runtime(
|
||||
"slack",
|
||||
ChannelRuntimeConfigRequest(values={"bot_token": "xoxb-ui", "app_token": "xapp-ui"}),
|
||||
request,
|
||||
)
|
||||
|
||||
assert response.provider == "slack"
|
||||
store = request.app.state.channel_runtime_config_store
|
||||
assert await asyncio.to_thread(store.get_provider_config, "slack") == {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
|
||||
|
||||
async def test_disconnect_runtime_channel_does_not_block_event_loop(tmp_path) -> None:
|
||||
request = await asyncio.to_thread(_make_request, tmp_path)
|
||||
store = request.app.state.channel_runtime_config_store
|
||||
await asyncio.to_thread(
|
||||
store.set_provider_config,
|
||||
"slack",
|
||||
{"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"},
|
||||
)
|
||||
request.app.state.channels_config = {
|
||||
"slack": {"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"},
|
||||
}
|
||||
|
||||
response = await disconnect_channel_provider_runtime("slack", request)
|
||||
|
||||
assert response.provider == "slack"
|
||||
assert await asyncio.to_thread(store.get_provider_config, "slack") == {
|
||||
"enabled": False,
|
||||
"_runtime_disabled": True,
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Regression anchor: ingesting inbound channel files must not block the event loop.
|
||||
|
||||
``ChannelManager``'s ``_ingest_inbound_files`` ensures the thread uploads
|
||||
directory (``mkdir``), enumerates it (``iterdir`` / ``is_file``) to de-duplicate
|
||||
filenames, and writes each downloaded attachment to disk
|
||||
(``write_upload_file_no_symlink``) — all blocking filesystem IO. The async
|
||||
function offloads the directory prep and every per-file write via
|
||||
``asyncio.to_thread`` while keeping the genuinely async network read
|
||||
(``file_reader``) on the loop. If any of that regresses back onto the event
|
||||
loop, the strict Blockbuster gate raises ``BlockingError`` and this test fails.
|
||||
|
||||
Imports are kept at module top so any import-time IO runs at collection (outside
|
||||
the gate); the surface under test runs on the event loop inside the gated test.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.channels import manager as mgr
|
||||
from app.channels.message_bus import InboundMessage
|
||||
from deerflow.uploads.manager import get_uploads_dir
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_ingest_inbound_files_does_not_block_event_loop(tmp_path: Path, monkeypatch) -> None:
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
# Rebuild the cached Paths against the tmp home so uploads resolve under it.
|
||||
import deerflow.config.paths as paths_mod
|
||||
|
||||
monkeypatch.setattr(paths_mod, "_paths", None)
|
||||
|
||||
# Swap the network reader for an in-memory one: no real HTTP, so the only IO
|
||||
# left for this anchor to guard is the filesystem work.
|
||||
async def _fake_reader(f, client):
|
||||
return b"payload-bytes"
|
||||
|
||||
monkeypatch.setattr(mgr, "_read_http_inbound_file", _fake_reader)
|
||||
|
||||
msg = InboundMessage(
|
||||
channel_name="unit-test-channel", # absent from INBOUND_FILE_READERS -> default reader
|
||||
chat_id="c1",
|
||||
user_id="u1",
|
||||
text="hi",
|
||||
files=[{"type": "file", "filename": "report.txt"}],
|
||||
)
|
||||
|
||||
created = await mgr._ingest_inbound_files("t1", msg)
|
||||
|
||||
assert len(created) == 1
|
||||
assert created[0]["filename"] == "report.txt"
|
||||
written = await asyncio.to_thread(lambda: (get_uploads_dir("t1") / "report.txt").exists())
|
||||
assert written, "inbound file should be written under the tmp uploads dir"
|
||||
@@ -0,0 +1,73 @@
|
||||
"""Regression anchor: skill archive installation must not block the event loop.
|
||||
|
||||
``LocalSkillStorage.ainstall_skill_from_archive`` is the async entry point the
|
||||
gateway ``POST /skills/install`` route awaits. It extracts the archive,
|
||||
validates frontmatter, security-scans every installable file, and stages the
|
||||
skill into the custom directory — all filesystem work that previously ran
|
||||
inline on the event loop (zip extract, ``rglob`` enumeration, ``read_text``,
|
||||
``shutil.copytree``). The fix offloads those phases via ``asyncio.to_thread``
|
||||
while keeping the per-file LLM security scan as the only awaited work; if any
|
||||
phase regresses back onto the loop, the strict Blockbuster gate raises
|
||||
``BlockingError`` and this test fails.
|
||||
|
||||
Only the external LLM boundary (``scan_skill_content``) is stubbed — the
|
||||
archive, extraction, validation, and staging all run against the real local
|
||||
filesystem. Test-side setup IO is itself offloaded with ``asyncio.to_thread``
|
||||
(matching ``test_agents_router``) so only the production path is exercised on
|
||||
the loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.skills.storage.local_skill_storage import LocalSkillStorage
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
_SKILL_MD = """---
|
||||
name: loop-skill
|
||||
description: Anchor fixture skill for the blocking-IO gate.
|
||||
---
|
||||
|
||||
# Loop Skill
|
||||
|
||||
Drives the full install pipeline under the Blockbuster gate.
|
||||
"""
|
||||
|
||||
_SUPPORT_MD = "Reference notes scanned by the per-file security pass.\n"
|
||||
|
||||
|
||||
def _build_archive(archive: Path) -> None:
|
||||
with zipfile.ZipFile(archive, "w") as zf:
|
||||
zf.writestr("loop-skill/SKILL.md", _SKILL_MD)
|
||||
zf.writestr("loop-skill/references/usage.md", _SUPPORT_MD)
|
||||
|
||||
|
||||
async def test_install_skill_archive_does_not_block_event_loop(tmp_path: Path, monkeypatch) -> None:
|
||||
archive = tmp_path / "loop-skill.skill"
|
||||
await asyncio.to_thread(_build_archive, archive)
|
||||
|
||||
async def _allow_scan(content: str, *, executable: bool = False, location: str = "SKILL.md", app_config=None):
|
||||
return SimpleNamespace(decision="allow", reason="anchor stub")
|
||||
|
||||
# External dependency boundary only: the security scanner is an LLM call.
|
||||
monkeypatch.setattr("deerflow.skills.installer.scan_skill_content", _allow_scan)
|
||||
|
||||
# Constructor resolves paths (one-time, cached in production via
|
||||
# get_or_new_skill_storage); offloaded here so the anchor exercises only
|
||||
# the install pipeline itself on the loop.
|
||||
storage = await asyncio.to_thread(LocalSkillStorage, host_path=str(tmp_path / "skills"))
|
||||
|
||||
result = await storage.ainstall_skill_from_archive(archive)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["skill_name"] == "loop-skill"
|
||||
installed_md = tmp_path / "skills" / "custom" / "loop-skill" / "SKILL.md"
|
||||
assert await asyncio.to_thread(installed_md.exists)
|
||||
assert await asyncio.to_thread((tmp_path / "skills" / "custom" / "loop-skill" / "references" / "usage.md").exists)
|
||||
@@ -69,6 +69,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -79,6 +80,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -89,6 +91,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -99,6 +102,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -109,6 +113,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
@@ -119,6 +124,7 @@
|
||||
"keys": [
|
||||
"artifacts",
|
||||
"messages",
|
||||
"sandbox",
|
||||
"thread_data",
|
||||
"title",
|
||||
"viewed_images"
|
||||
|
||||
@@ -0,0 +1,212 @@
|
||||
"""Intersect a git diff with static blocking-IO findings.
|
||||
|
||||
Wraps the static detector (`blocking_io_static`) to answer a narrower question:
|
||||
which blocking-IO candidates does THIS change introduce? A candidate qualifies
|
||||
when its blocking line is on an added line of the diff, or when the finding is
|
||||
new versus the merge base — the latter catches exposure created without
|
||||
touching the blocking line itself (a new async caller making an old sync
|
||||
helper async-reachable). Used by the `blocking-io-guard` skill as the
|
||||
deterministic scope step.
|
||||
|
||||
Not directly executable: import as `support.detectors.blocking_io_changed` or
|
||||
run via the CLI shim `scripts/scan_changed_blocking_io.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from support.detectors import blocking_io_static as static
|
||||
from support.detectors.repo_root import resolve_repo_root
|
||||
|
||||
REPO_ROOT = resolve_repo_root(Path(__file__))
|
||||
SCAN_ROOTS = (
|
||||
"backend/app",
|
||||
"backend/packages/harness/deerflow",
|
||||
"backend/scripts",
|
||||
)
|
||||
|
||||
_HUNK_RE = re.compile(r"^@@ -\d+(?:,\d+)? \+(\d+)(?:,\d+)? @@")
|
||||
|
||||
|
||||
def parse_changed_lines(diff_text: str) -> dict[str, set[int]]:
|
||||
"""Map repo-relative path -> set of added line numbers in the new file.
|
||||
|
||||
Accepts any unified diff (with or without `--unified=0`): context lines
|
||||
advance the new-file counter, deletions (`-`) and `\\ No newline` markers
|
||||
do not. Records only added lines (`+`, not the `+++` header), numbered
|
||||
from each hunk's new-file start line; deleted files (`+++ /dev/null`) are
|
||||
skipped.
|
||||
"""
|
||||
changed: dict[str, set[int]] = defaultdict(set)
|
||||
current_path: str | None = None
|
||||
next_line = 0
|
||||
for raw in diff_text.splitlines():
|
||||
if raw.startswith("+++ "):
|
||||
target = raw[4:].strip()
|
||||
if target == "/dev/null":
|
||||
current_path = None
|
||||
else:
|
||||
current_path = target[2:] if target.startswith("b/") else target
|
||||
next_line = 0
|
||||
continue
|
||||
match = _HUNK_RE.match(raw)
|
||||
if match:
|
||||
next_line = int(match.group(1))
|
||||
continue
|
||||
if not current_path:
|
||||
continue
|
||||
if raw.startswith("+"):
|
||||
changed[current_path].add(next_line)
|
||||
next_line += 1
|
||||
elif raw.startswith(" ") or raw == "":
|
||||
next_line += 1
|
||||
return dict(changed)
|
||||
|
||||
|
||||
def changed_python_lines(base: str, repo_root: Path = REPO_ROOT) -> dict[str, set[int]]:
|
||||
"""Diff `base...HEAD` over scan roots and return added .py lines."""
|
||||
cmd = [
|
||||
"git",
|
||||
"-C",
|
||||
str(repo_root),
|
||||
"diff",
|
||||
"--unified=0",
|
||||
"--no-color",
|
||||
f"{base}...HEAD",
|
||||
"--",
|
||||
*SCAN_ROOTS,
|
||||
]
|
||||
diff_text = subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
|
||||
return {path: lines for path, lines in parse_changed_lines(diff_text).items() if path.endswith(".py")}
|
||||
|
||||
|
||||
def select_findings_on_changed_lines(
|
||||
findings: Sequence[dict[str, object]],
|
||||
changed_lines: dict[str, set[int]],
|
||||
) -> list[dict[str, object]]:
|
||||
"""Keep findings whose (path, line) falls on a changed line."""
|
||||
selected: list[dict[str, object]] = []
|
||||
for finding in findings:
|
||||
location = finding["location"] # type: ignore[index]
|
||||
path = location["path"] # type: ignore[index]
|
||||
line = location["line"] # type: ignore[index]
|
||||
if line in changed_lines.get(path, set()):
|
||||
selected.append(finding)
|
||||
return selected
|
||||
|
||||
|
||||
def base_python_contents(base: str, paths: Sequence[str], repo_root: Path = REPO_ROOT) -> dict[str, str]:
|
||||
"""Return each path's content at the merge base of `base` and HEAD.
|
||||
|
||||
Files absent at the merge base (newly added) are omitted, so every head
|
||||
finding in them counts as new.
|
||||
"""
|
||||
merge_base = subprocess.run(
|
||||
["git", "-C", str(repo_root), "merge-base", base, "HEAD"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
).stdout.strip()
|
||||
contents: dict[str, str] = {}
|
||||
for path in paths:
|
||||
shown = subprocess.run(
|
||||
["git", "-C", str(repo_root), "show", f"{merge_base}:{path}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if shown.returncode == 0:
|
||||
contents[path] = shown.stdout
|
||||
return contents
|
||||
|
||||
|
||||
def scan_python_contents(contents: dict[str, str]) -> list[dict[str, object]]:
|
||||
"""Run the static detector over in-memory sources (repo-relative path -> code)."""
|
||||
findings: list[dict[str, object]] = []
|
||||
for rel_path in sorted(contents):
|
||||
findings.extend(finding.to_dict() for finding in static.scan_source(contents[rel_path], rel_path))
|
||||
return findings
|
||||
|
||||
|
||||
def _stable_key(finding: dict[str, object]) -> tuple[str, str, str]:
|
||||
location = finding["location"] # type: ignore[index]
|
||||
call = finding["blocking_call"] # type: ignore[index]
|
||||
return (location["path"], location["function"], call["symbol"]) # type: ignore[index]
|
||||
|
||||
|
||||
def select_findings_new_vs_base(
|
||||
head_findings: Sequence[dict[str, object]],
|
||||
base_findings: Sequence[dict[str, object]],
|
||||
) -> list[dict[str, object]]:
|
||||
"""Keep head findings whose stable key (path, function, symbol) is absent at base.
|
||||
|
||||
Line numbers shift between revisions, so matching is by stable key only.
|
||||
A second identical symbol added inside a function that already had a
|
||||
finding collides on the key and is NOT reported here — that case is
|
||||
covered by the changed-line selection instead.
|
||||
"""
|
||||
base_keys = {_stable_key(finding) for finding in base_findings}
|
||||
return [finding for finding in head_findings if _stable_key(finding) not in base_keys]
|
||||
|
||||
|
||||
def find_changed_blocking_io(base: str, repo_root: Path = REPO_ROOT) -> list[dict[str, object]]:
|
||||
"""Return static findings this change introduces or touches.
|
||||
|
||||
Union over the changed files of:
|
||||
- findings whose blocking line is on an added line of the diff;
|
||||
- findings new versus the merge base (a new async caller can expose an
|
||||
untouched sync helper — the blocking line itself is not in the diff).
|
||||
"""
|
||||
changed_lines = changed_python_lines(base, repo_root)
|
||||
if not changed_lines:
|
||||
return []
|
||||
files = [repo_root / path for path in changed_lines]
|
||||
head_findings = [finding.to_dict() for finding in static.scan_paths(files, repo_root=repo_root)]
|
||||
on_changed_lines = select_findings_on_changed_lines(head_findings, changed_lines)
|
||||
base_findings = scan_python_contents(base_python_contents(base, sorted(changed_lines), repo_root))
|
||||
new_vs_base = select_findings_new_vs_base(head_findings, base_findings)
|
||||
selected_keys = {_stable_key(finding) for finding in (*on_changed_lines, *new_vs_base)}
|
||||
return [finding for finding in head_findings if _stable_key(finding) in selected_keys]
|
||||
|
||||
|
||||
def format_report(findings: Sequence[dict[str, object]], base: str) -> str:
|
||||
if not findings:
|
||||
return (
|
||||
f"No blocking-IO candidates introduced by this change (base: {base}).\n"
|
||||
"Note: async reachability is resolved within each file only. If this change\n"
|
||||
"adds an async call into a sync helper defined in another file, check that\n"
|
||||
"helper manually (codegraph or git grep) before relying on this empty result."
|
||||
)
|
||||
lines = [
|
||||
f"Blocking-IO candidates introduced/touched by this change (base: {base}): {len(findings)}",
|
||||
"",
|
||||
]
|
||||
order = {"HIGH": 0, "MEDIUM": 1, "LOW": 2}
|
||||
for finding in sorted(findings, key=lambda f: order.get(str(f["priority"]), 9)):
|
||||
location = finding["location"] # type: ignore[index]
|
||||
call = finding["blocking_call"] # type: ignore[index]
|
||||
lines.append(f"{finding['priority']} {call['category']}/{call['operation']} {location['path']}:{location['line']} in {location['function']} exposure={finding['event_loop_exposure']}")
|
||||
lines.append(f" symbol: {call['symbol']}")
|
||||
if finding.get("code"):
|
||||
lines.append(f" code: {finding['code']}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main(argv: Sequence[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(description="List blocking-IO candidates this change introduces: findings on added lines plus findings new versus the merge base (diff against --base).")
|
||||
parser.add_argument("--base", default="origin/main", help="Base ref to diff against (default: origin/main).")
|
||||
parser.add_argument("--format", choices=("text", "json"), default="text", help="Output format.")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
findings = find_changed_blocking_io(args.base)
|
||||
if args.format == "json":
|
||||
print(json.dumps(findings, indent=2))
|
||||
else:
|
||||
print(format_report(findings, args.base))
|
||||
return 0
|
||||
@@ -1,9 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Static inventory for likely backend event-loop blocking IO.
|
||||
|
||||
This detector parses backend business source with AST so untested paths are
|
||||
still visible during review. Findings are prioritized static candidates, not
|
||||
automatic bug decisions.
|
||||
|
||||
Not directly executable: import as `support.detectors.blocking_io_static` or
|
||||
run via the CLI shim `scripts/detect_blocking_io_static.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -12,13 +14,14 @@ import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
from support.detectors.repo_root import resolve_repo_root
|
||||
|
||||
REPO_ROOT = resolve_repo_root(Path(__file__))
|
||||
DEFAULT_SCAN_PATHS = (
|
||||
REPO_ROOT / "backend" / "app",
|
||||
REPO_ROOT / "backend" / "packages" / "harness" / "deerflow",
|
||||
@@ -717,12 +720,11 @@ def _finalize_findings(visitor: BlockingIOStaticVisitor) -> list[BlockingIOStati
|
||||
return findings
|
||||
|
||||
|
||||
def scan_file(path: Path, *, repo_root: Path = REPO_ROOT) -> list[BlockingIOStaticFinding]:
|
||||
source = path.read_text(encoding="utf-8")
|
||||
def scan_source(source: str, relative_path: str) -> list[BlockingIOStaticFinding]:
|
||||
"""Scan one in-memory Python source; `relative_path` is reported verbatim in findings."""
|
||||
source_lines = source.splitlines()
|
||||
relative_path = relative_to_repo(path, repo_root)
|
||||
try:
|
||||
tree = ast.parse(source, filename=str(path))
|
||||
tree = ast.parse(source, filename=relative_path)
|
||||
except SyntaxError as exc:
|
||||
line = exc.lineno or 0
|
||||
code = _source_snippet(source_lines, line)
|
||||
@@ -746,6 +748,10 @@ def scan_file(path: Path, *, repo_root: Path = REPO_ROOT) -> list[BlockingIOStat
|
||||
return sorted(_finalize_findings(visitor), key=lambda finding: (finding.path, finding.line, finding.column, finding.category))
|
||||
|
||||
|
||||
def scan_file(path: Path, *, repo_root: Path = REPO_ROOT) -> list[BlockingIOStaticFinding]:
|
||||
return scan_source(path.read_text(encoding="utf-8"), relative_to_repo(path, repo_root))
|
||||
|
||||
|
||||
def is_ignored_path(path: Path) -> bool:
|
||||
return any(part in IGNORED_DIR_NAMES for part in path.parts)
|
||||
|
||||
@@ -886,7 +892,3 @@ def main(argv: Sequence[str] | None = None) -> int:
|
||||
else:
|
||||
print(format_text(findings))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Fail-loud repository-root resolution shared by the detectors.
|
||||
|
||||
Depth-indexed resolution (`Path(__file__).resolve().parents[N]`) fails
|
||||
silently when a detector file moves to a different directory depth: scan
|
||||
roots resolve under the wrong directory, nothing is scanned, and the
|
||||
detector reports zero findings with no error. Walking upward to a
|
||||
repository marker turns that into an immediate error instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT_MARKER = ".git"
|
||||
|
||||
|
||||
def resolve_repo_root(start: Path) -> Path:
|
||||
"""Return the repository root above `start` (the directory containing `.git`).
|
||||
|
||||
`.git` is checked with `exists()` rather than `is_dir()` so git worktrees
|
||||
(where `.git` is a file) resolve correctly.
|
||||
|
||||
Raises:
|
||||
RuntimeError: when no marker is found above `start`, so a relocated
|
||||
detector fails loudly instead of silently scanning an empty tree.
|
||||
"""
|
||||
resolved = start.resolve()
|
||||
for candidate in (resolved, *resolved.parents):
|
||||
if (candidate / REPO_ROOT_MARKER).exists():
|
||||
return candidate
|
||||
raise RuntimeError(f"could not resolve the repository root: no '{REPO_ROOT_MARKER}' marker found above {resolved}; refusing to guess scan paths")
|
||||
@@ -1,9 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Inventory async/thread boundary points for developer review.
|
||||
|
||||
This detector is intentionally non-invasive: it parses Python source with AST
|
||||
and reports places where code crosses sync/async/thread boundaries. Findings
|
||||
are review evidence, not automatic bug decisions.
|
||||
|
||||
Not directly executable: import as `support.detectors.thread_boundaries` or
|
||||
run via the CLI shim `scripts/detect_thread_boundaries.py`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -12,12 +14,13 @@ import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
from support.detectors.repo_root import resolve_repo_root
|
||||
|
||||
REPO_ROOT = resolve_repo_root(Path(__file__))
|
||||
DEFAULT_SCAN_PATHS = (
|
||||
REPO_ROOT / "backend" / "app",
|
||||
REPO_ROOT / "backend" / "packages" / "harness" / "deerflow",
|
||||
@@ -501,7 +504,3 @@ def main(argv: Sequence[str] | None = None) -> int:
|
||||
else:
|
||||
print(format_text(findings))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
@@ -0,0 +1,251 @@
|
||||
"""Connection binding tests for browser-connectable IM channels beyond Telegram/Slack/Discord."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.channels.message_bus import InboundMessage, MessageBus
|
||||
|
||||
|
||||
async def _make_repo(tmp_path, name: str):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / f'{name}.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _seed_state(repo, provider: str, state: str, owner_user_id: str = "deerflow-user-1") -> None:
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
|
||||
|
||||
def test_feishu_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.feishu import FeishuChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "feishu")
|
||||
state = "feishu-bind-code"
|
||||
await _seed_state(repo, "feishu", state)
|
||||
channel = FeishuChannel(
|
||||
bus=MessageBus(),
|
||||
config={"app_id": "app", "app_secret": "secret", "connection_repo": repo},
|
||||
)
|
||||
channel._reply_card = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
message_id="om-message-1",
|
||||
chat_id="oc-chat-1",
|
||||
user_id="ou-user-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "feishu"
|
||||
assert connections[0]["external_account_id"] == "ou-user-1"
|
||||
assert connections[0]["workspace_id"] == "oc-chat-1"
|
||||
channel._reply_card.assert_awaited_once_with("om-message-1", "Feishu connected to DeerFlow.")
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_dingtalk_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "dingtalk")
|
||||
state = "dingtalk-bind-code"
|
||||
await _seed_state(repo, "dingtalk", state)
|
||||
channel = DingTalkChannel(
|
||||
bus=MessageBus(),
|
||||
config={"client_id": "client", "client_secret": "secret", "connection_repo": repo},
|
||||
)
|
||||
channel._send_connection_reply = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
conversation_type=_CONVERSATION_TYPE_GROUP,
|
||||
sender_staff_id="staff-user-1",
|
||||
sender_nick="Alice",
|
||||
conversation_id="cid-group-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "dingtalk"
|
||||
assert connections[0]["external_account_id"] == "staff-user-1"
|
||||
assert connections[0]["external_account_name"] == "Alice"
|
||||
assert connections[0]["workspace_id"] == "cid-group-1"
|
||||
channel._send_connection_reply.assert_awaited_once()
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_wechat_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.wechat import WechatChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "wechat")
|
||||
state = "wechat-bind-code"
|
||||
await _seed_state(repo, "wechat", state)
|
||||
channel = WechatChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "token", "connection_repo": repo},
|
||||
)
|
||||
channel._send_connection_reply = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
chat_id="wx-user-1",
|
||||
context_token="ctx-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "wechat"
|
||||
assert connections[0]["external_account_id"] == "wx-user-1"
|
||||
assert connections[0]["workspace_id"] == "wx-user-1"
|
||||
channel._send_connection_reply.assert_awaited_once_with("wx-user-1", "ctx-1", "WeChat connected to DeerFlow.")
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_wecom_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "wecom")
|
||||
state = "wecom-bind-code"
|
||||
await _seed_state(repo, "wecom", state)
|
||||
channel = WeComChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_id": "bot", "bot_secret": "secret", "connection_repo": repo},
|
||||
)
|
||||
channel._ws_client = MagicMock()
|
||||
channel._ws_client.reply = AsyncMock()
|
||||
frame = {"body": {"aibotid": "bot-1", "chattype": "single"}}
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id="wecom-user-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "wecom"
|
||||
assert connections[0]["external_account_id"] == "wecom-user-1"
|
||||
assert connections[0]["workspace_id"] == "bot-1"
|
||||
channel._ws_client.reply.assert_awaited_once_with(frame, {"msgtype": "text", "text": {"content": "WeCom connected to DeerFlow."}})
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_additional_channels_attach_owner_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
|
||||
from app.channels.feishu import FeishuChannel
|
||||
from app.channels.wechat import WechatChannel
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "additional-identity")
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="feishu",
|
||||
external_account_id="ou-user-1",
|
||||
workspace_id="oc-chat-1",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="dingtalk",
|
||||
external_account_id="staff-user-1",
|
||||
workspace_id="cid-group-1",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="wechat",
|
||||
external_account_id="wx-user-1",
|
||||
workspace_id="wx-user-1",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="wecom",
|
||||
external_account_id="wecom-user-1",
|
||||
workspace_id="bot-1",
|
||||
)
|
||||
|
||||
cases = [
|
||||
(
|
||||
FeishuChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(channel_name="feishu", chat_id="oc-chat-1", user_id="ou-user-1", text="hello"),
|
||||
),
|
||||
(
|
||||
DingTalkChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(
|
||||
channel_name="dingtalk",
|
||||
chat_id="cid-group-1",
|
||||
user_id="staff-user-1",
|
||||
text="hello",
|
||||
metadata={
|
||||
"conversation_type": _CONVERSATION_TYPE_GROUP,
|
||||
"conversation_id": "cid-group-1",
|
||||
},
|
||||
),
|
||||
),
|
||||
(
|
||||
WechatChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(channel_name="wechat", chat_id="wx-user-1", user_id="wx-user-1", text="hello"),
|
||||
),
|
||||
(
|
||||
WeComChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(
|
||||
channel_name="wecom",
|
||||
chat_id="wecom-user-1",
|
||||
user_id="wecom-user-1",
|
||||
text="hello",
|
||||
metadata={"aibotid": "bot-1"},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
for channel, inbound in cases:
|
||||
attached = await channel._attach_connection_identity(inbound)
|
||||
assert attached.owner_user_id == "deerflow-user-1"
|
||||
assert attached.connection_id
|
||||
assert (
|
||||
attached.workspace_id
|
||||
== {
|
||||
"feishu": "oc-chat-1",
|
||||
"dingtalk": "cid-group-1",
|
||||
"wechat": "wx-user-1",
|
||||
"wecom": "bot-1",
|
||||
}[channel.name]
|
||||
)
|
||||
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
@@ -1,7 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.community.aio_sandbox.local_backend import (
|
||||
LocalContainerBackend,
|
||||
_format_container_command_for_log,
|
||||
@@ -234,3 +237,99 @@ def test_start_container_keeps_apple_container_port_format(monkeypatch):
|
||||
captured_cmd = _capture_start_container_command(monkeypatch, backend, runtime="container")
|
||||
|
||||
assert captured_cmd[captured_cmd.index("-p") + 1] == "18080:8080"
|
||||
|
||||
|
||||
def _backend_for_inspect_tests() -> LocalContainerBackend:
|
||||
backend = LocalContainerBackend(
|
||||
image="sandbox:latest",
|
||||
base_port=8080,
|
||||
container_prefix="sandbox",
|
||||
config_mounts=[],
|
||||
environment={},
|
||||
)
|
||||
backend._runtime = "docker"
|
||||
return backend
|
||||
|
||||
|
||||
def test_is_container_running_false_when_container_missing(monkeypatch):
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Error: No such object: sandbox-missing", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend._is_container_running("sandbox-missing") is False
|
||||
|
||||
|
||||
def test_is_container_running_raises_on_runtime_error(monkeypatch):
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"):
|
||||
backend._is_container_running("sandbox-busy")
|
||||
|
||||
|
||||
def test_is_container_running_raises_on_timeout(monkeypatch):
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"])
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Timed out checking container sandbox-timeout"):
|
||||
backend._is_container_running("sandbox-timeout")
|
||||
|
||||
|
||||
def test_discover_returns_none_when_runtime_check_fails(monkeypatch):
|
||||
"""A transient daemon error during discovery must fall through to create, not fail acquire."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend.discover("sandbox-blip") is None
|
||||
|
||||
|
||||
def test_discover_returns_none_when_runtime_check_times_out(monkeypatch):
|
||||
"""An inspect timeout during discovery must not propagate out of discover()."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"])
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend.discover("sandbox-timeout") is None
|
||||
|
||||
|
||||
def test_is_container_running_false_on_apple_container_not_found(monkeypatch):
|
||||
"""Apple Container's generic "not found" is trusted when it names the container."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr='Error: not found: "sandbox-apple"', returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
assert backend._is_container_running("sandbox-apple") is False
|
||||
|
||||
|
||||
def test_is_container_running_raises_on_unrelated_not_found_error(monkeypatch):
|
||||
"""Transient errors whose text contains "not found" must not be misread as a dead container."""
|
||||
backend = _backend_for_inspect_tests()
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
return SimpleNamespace(stdout="", stderr="Error: credential helper not found in $PATH", returncode=1)
|
||||
|
||||
monkeypatch.setattr("subprocess.run", fake_run)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"):
|
||||
backend._is_container_running("sandbox-busy")
|
||||
|
||||
@@ -317,6 +317,28 @@ async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path,
|
||||
pytest.fail("provider thread lock was not released after successor acquire_async")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_acquire_internal_async_offloads_cached_reuse_health_check(tmp_path, monkeypatch):
|
||||
"""Async cached reuse must keep backend health checks off the event loop."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider, _sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-cached-async")
|
||||
provider._thread_sandboxes = {"thread-cached-async": "sandbox-cached-async"}
|
||||
provider._backend.is_alive = MagicMock(return_value=True)
|
||||
|
||||
to_thread_calls: list[tuple[object, tuple[object, ...]]] = []
|
||||
|
||||
async def fake_to_thread(func, /, *args, **kwargs):
|
||||
to_thread_calls.append((func, args))
|
||||
return func(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread)
|
||||
|
||||
sandbox_id = await provider._acquire_internal_async("thread-cached-async")
|
||||
|
||||
assert sandbox_id == "sandbox-cached-async"
|
||||
assert to_thread_calls == [(provider._reuse_in_process_sandbox, ("thread-cached-async",))]
|
||||
|
||||
|
||||
def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
|
||||
"""Provisioner mode must receive user_id so PVC subPath matches user isolation."""
|
||||
remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend")
|
||||
@@ -424,6 +446,136 @@ def test_release_swallows_close_errors(tmp_path, caplog):
|
||||
assert "sandbox-rel-err" in provider._warm_pool
|
||||
|
||||
|
||||
def test_get_uses_in_memory_registry_only(tmp_path):
|
||||
"""get() must stay event-loop safe by avoiding backend health checks."""
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead")
|
||||
provider._backend.is_alive = MagicMock(side_effect=AssertionError("get must not call backend health checks"))
|
||||
|
||||
assert provider.get("sandbox-dead") is sandbox
|
||||
|
||||
|
||||
def test_acquire_drops_dead_cached_sandbox(tmp_path, monkeypatch):
|
||||
"""acquire() must replace a stale active cache entry after its container dies."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead")
|
||||
provider._thread_locks = {}
|
||||
provider._thread_sandboxes = {"thread-dead": "sandbox-dead"}
|
||||
provider._config = {"replicas": 3}
|
||||
provider._backend.is_alive = MagicMock(return_value=False)
|
||||
provider._backend.discover = MagicMock(return_value=None)
|
||||
provider._backend.create = MagicMock(
|
||||
return_value=aio_mod.SandboxInfo(
|
||||
sandbox_id="sandbox-dead",
|
||||
sandbox_url="http://fresh-sandbox",
|
||||
container_name="deer-flow-sandbox-sandbox-dead",
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-dead")
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: [])
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True)
|
||||
|
||||
sandbox_id = provider.acquire("thread-dead")
|
||||
|
||||
assert sandbox_id == "sandbox-dead"
|
||||
sandbox.close.assert_called_once_with()
|
||||
provider._backend.destroy.assert_called_once()
|
||||
provider._backend.create.assert_called_once()
|
||||
assert provider._thread_sandboxes["thread-dead"] == "sandbox-dead"
|
||||
assert provider._sandboxes["sandbox-dead"].base_url == "http://fresh-sandbox"
|
||||
|
||||
|
||||
def test_acquire_keeps_cached_sandbox_when_health_check_errors(tmp_path):
|
||||
"""Transient backend health-check errors must not destroy a tracked sandbox."""
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-transient")
|
||||
provider._thread_locks = {}
|
||||
provider._thread_sandboxes = {"thread-transient": "sandbox-transient"}
|
||||
provider._backend.is_alive = MagicMock(side_effect=OSError("docker daemon busy"))
|
||||
|
||||
sandbox_id = provider.acquire("thread-transient")
|
||||
|
||||
assert sandbox_id == "sandbox-transient"
|
||||
sandbox.close.assert_not_called()
|
||||
provider._backend.destroy.assert_not_called()
|
||||
assert provider._sandboxes["sandbox-transient"] is sandbox
|
||||
|
||||
|
||||
def test_drop_unhealthy_sandbox_skips_recreated_entry(tmp_path):
|
||||
"""A stale health-check result must not delete a newly registered sandbox."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider = _make_provider(tmp_path)
|
||||
provider._lock = aio_mod.threading.Lock()
|
||||
provider._warm_pool = {}
|
||||
provider._last_activity = {"sandbox-toctou": 1.0}
|
||||
provider._thread_sandboxes = {"thread-toctou": "sandbox-toctou"}
|
||||
old_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://old-sandbox")
|
||||
new_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://new-sandbox")
|
||||
new_sandbox = MagicMock()
|
||||
provider._sandbox_infos = {"sandbox-toctou": new_info}
|
||||
provider._sandboxes = {"sandbox-toctou": new_sandbox}
|
||||
provider._backend = SimpleNamespace(destroy=MagicMock())
|
||||
|
||||
provider._drop_unhealthy_sandbox("sandbox-toctou", "stale health check", expected_info=old_info)
|
||||
|
||||
new_sandbox.close.assert_not_called()
|
||||
provider._backend.destroy.assert_not_called()
|
||||
assert provider._sandbox_infos["sandbox-toctou"] is new_info
|
||||
assert provider._sandboxes["sandbox-toctou"] is new_sandbox
|
||||
assert provider._thread_sandboxes == {"thread-toctou": "sandbox-toctou"}
|
||||
|
||||
|
||||
def test_acquire_skips_dead_warm_pool_sandbox(tmp_path, monkeypatch):
|
||||
"""acquire() must create a fresh sandbox when the warm-pool entry died."""
|
||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||
provider = _make_provider(tmp_path)
|
||||
provider._lock = aio_mod.threading.Lock()
|
||||
provider._thread_locks = {}
|
||||
provider._sandboxes = {}
|
||||
provider._sandbox_infos = {}
|
||||
provider._thread_sandboxes = {}
|
||||
provider._last_activity = {}
|
||||
provider._warm_pool = {
|
||||
"sandbox-warm-dead": (
|
||||
aio_mod.SandboxInfo(
|
||||
sandbox_id="sandbox-warm-dead",
|
||||
sandbox_url="http://stale-sandbox",
|
||||
container_name="deer-flow-sandbox-sandbox-warm-dead",
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
}
|
||||
provider._config = {"replicas": 3}
|
||||
provider._backend = SimpleNamespace(
|
||||
is_alive=MagicMock(return_value=False),
|
||||
destroy=MagicMock(),
|
||||
discover=MagicMock(return_value=None),
|
||||
create=MagicMock(
|
||||
return_value=aio_mod.SandboxInfo(
|
||||
sandbox_id="sandbox-warm-dead",
|
||||
sandbox_url="http://fresh-sandbox",
|
||||
container_name="deer-flow-sandbox-sandbox-warm-dead",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-warm-dead")
|
||||
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: [])
|
||||
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
|
||||
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True)
|
||||
|
||||
sandbox_id = provider.acquire("thread-warm-dead")
|
||||
|
||||
assert sandbox_id == "sandbox-warm-dead"
|
||||
provider._backend.destroy.assert_called_once()
|
||||
provider._backend.create.assert_called_once()
|
||||
assert provider._warm_pool == {}
|
||||
assert provider._thread_sandboxes["thread-warm-dead"] == "sandbox-warm-dead"
|
||||
assert provider._sandboxes["sandbox-warm-dead"].base_url == "http://fresh-sandbox"
|
||||
|
||||
|
||||
def test_destroy_swallows_close_errors_and_still_destroys_backend(tmp_path, caplog):
|
||||
"""A failure in sandbox.close() must not skip backend container destruction."""
|
||||
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dest-err")
|
||||
|
||||
@@ -280,6 +280,74 @@ def test_require_permission_denies_wrong_permission():
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
|
||||
|
||||
def _make_internal_owner_check_app():
|
||||
"""App with an owner_check route and a thread owned by ``alice``."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import Request
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
|
||||
app = FastAPI()
|
||||
thread_store = MemoryThreadMetaStore(InMemoryStore())
|
||||
asyncio.run(thread_store.create("alice-thread", user_id="alice"))
|
||||
app.state.thread_store = thread_store
|
||||
|
||||
@app.get("/threads/{thread_id}")
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def endpoint(thread_id: str, request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _internal_auth_context() -> AuthContext:
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||
|
||||
user = SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)
|
||||
return AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
|
||||
def test_require_permission_internal_role_scoped_by_owner_header():
|
||||
"""An internal caller acting for the thread owner passes the owner check."""
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
app = _make_internal_owner_check_app()
|
||||
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/threads/alice-thread",
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "alice"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_require_permission_internal_role_denied_for_other_owner():
|
||||
"""The internal token must not grant access to another user's thread."""
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
app = _make_internal_owner_check_app()
|
||||
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/threads/alice-thread",
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "mallory"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_require_permission_internal_role_without_header_is_scoped_to_internal_user():
|
||||
"""With no owner header, internal callers are scoped like before the bypass."""
|
||||
app = _make_internal_owner_check_app()
|
||||
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/threads/alice-thread")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||
|
||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||
|
||||
@@ -38,6 +39,8 @@ def test_public_paths(path: str):
|
||||
"/api/threads/123/uploads",
|
||||
"/api/agents",
|
||||
"/api/channels",
|
||||
"/api/channels/providers",
|
||||
"/api/channels/slack/connect",
|
||||
"/api/runs/stream",
|
||||
"/api/threads/123/runs",
|
||||
"/api/v1/auth/me",
|
||||
@@ -88,7 +91,9 @@ def test_unknown_api_path_is_protected():
|
||||
|
||||
def _make_app():
|
||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
@@ -98,8 +103,16 @@ def _make_app():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/v1/auth/me")
|
||||
async def auth_me():
|
||||
return {"id": "1", "email": "test@test.com"}
|
||||
async def auth_me(request: Request):
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": user.email,
|
||||
"system_role": user.system_role,
|
||||
"needs_setup": user.needs_setup,
|
||||
}
|
||||
|
||||
@app.get("/api/v1/auth/setup-status")
|
||||
async def setup_status():
|
||||
@@ -109,6 +122,29 @@ def _make_app():
|
||||
async def models_get():
|
||||
return {"models": []}
|
||||
|
||||
@app.get("/api/whoami")
|
||||
async def whoami(request: Request):
|
||||
user = request.state.user
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"email": getattr(user, "email", None),
|
||||
"system_role": getattr(user, "system_role", None),
|
||||
"context_user_id": get_effective_user_id(),
|
||||
}
|
||||
|
||||
@app.get("/api/current-user-from-dep")
|
||||
async def current_user_from_dep(request: Request):
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
state_user = request.state.user
|
||||
return {
|
||||
"id": str(user.id),
|
||||
"state_id": str(state_user.id),
|
||||
"auth_source": request.state.auth_source,
|
||||
"context_user_id": get_effective_user_id(),
|
||||
}
|
||||
|
||||
@app.put("/api/mcp/config")
|
||||
async def mcp_put():
|
||||
return {"ok": True}
|
||||
@@ -132,8 +168,24 @@ def _make_app():
|
||||
return app
|
||||
|
||||
|
||||
def _make_auth_csrf_app():
|
||||
"""Create a minimal app with production middleware ordering."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def protected_mutation():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
def client(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "")
|
||||
return TestClient(_make_app())
|
||||
|
||||
|
||||
@@ -161,11 +213,145 @@ def test_protected_path_no_cookie_returns_401(client):
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
client = TestClient(_make_app())
|
||||
|
||||
res = client.get("/api/models")
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"models": []}
|
||||
|
||||
|
||||
def test_auth_disabled_stamps_default_admin_user_without_cookie(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
client = TestClient(_make_app())
|
||||
|
||||
res = client.get("/api/whoami")
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"id": "default",
|
||||
"email": "default@test.local",
|
||||
"system_role": "admin",
|
||||
"context_user_id": "default",
|
||||
}
|
||||
|
||||
|
||||
def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
client = TestClient(_make_app())
|
||||
|
||||
res = client.get("/api/v1/auth/me")
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"id": "default",
|
||||
"email": "default@test.local",
|
||||
"system_role": "admin",
|
||||
"needs_setup": False,
|
||||
}
|
||||
|
||||
|
||||
def test_auth_disabled_does_not_clobber_valid_session_cookie(monkeypatch):
|
||||
from types import SimpleNamespace
|
||||
|
||||
async def fake_current_user(request):
|
||||
return SimpleNamespace(
|
||||
id="session-user",
|
||||
email="session@test.local",
|
||||
system_role="user",
|
||||
needs_setup=False,
|
||||
)
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
monkeypatch.setattr("app.gateway.deps.get_current_user_from_request", fake_current_user)
|
||||
client = TestClient(_make_app())
|
||||
|
||||
res = client.get("/api/whoami", cookies={"access_token": "valid-session"})
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"id": "session-user",
|
||||
"email": "session@test.local",
|
||||
"system_role": "user",
|
||||
"context_user_id": "session-user",
|
||||
}
|
||||
|
||||
|
||||
def test_auth_disabled_does_not_clobber_internal_auth_identity(monkeypatch):
|
||||
from app.gateway.internal_auth import create_internal_auth_headers
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
client = TestClient(_make_app())
|
||||
|
||||
res = client.get(
|
||||
"/api/current-user-from-dep",
|
||||
headers=create_internal_auth_headers(),
|
||||
)
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"id": DEFAULT_USER_ID,
|
||||
"state_id": DEFAULT_USER_ID,
|
||||
"auth_source": "internal",
|
||||
"context_user_id": DEFAULT_USER_ID,
|
||||
}
|
||||
|
||||
|
||||
def test_auth_disabled_skips_csrf_for_state_changing_requests(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
client = TestClient(_make_auth_csrf_app())
|
||||
|
||||
res = client.post("/api/threads/abc/runs/stream")
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"ok": True}
|
||||
|
||||
|
||||
def test_auth_disabled_is_ignored_in_explicit_production_env(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
monkeypatch.setenv("DEER_FLOW_ENV", "production")
|
||||
client = TestClient(_make_app())
|
||||
|
||||
res = client.get("/api/models")
|
||||
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
|
||||
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
|
||||
monkeypatch.delenv("ENVIRONMENT", raising=False)
|
||||
|
||||
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
|
||||
warn_if_auth_disabled_enabled()
|
||||
|
||||
assert "authentication is bypassed" in caplog.text
|
||||
assert "default" in caplog.text
|
||||
|
||||
|
||||
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
|
||||
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
monkeypatch.setenv("ENVIRONMENT", "production")
|
||||
|
||||
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
|
||||
warn_if_auth_disabled_enabled()
|
||||
|
||||
assert "authentication is bypassed" not in caplog.text
|
||||
|
||||
|
||||
def test_protected_path_with_junk_cookie_rejected(client):
|
||||
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
||||
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
||||
tokens through to the route handler."""
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
client.cookies.set("access_token", "some-token")
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
"""Unit tests for the Brave Search community web search tool."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_api_key_warned():
|
||||
"""Reset the module-level warning flag before each test."""
|
||||
import deerflow.community.brave.tools as brave_mod
|
||||
|
||||
brave_mod._api_key_warned = False
|
||||
yield
|
||||
brave_mod._api_key_warned = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_with_key():
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-brave-key", "max_results": 5}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_no_key():
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
yield mock
|
||||
|
||||
|
||||
def _make_brave_response(results: list) -> MagicMock:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {"web": {"results": results}}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
return mock_resp
|
||||
|
||||
|
||||
def _count_aware_get(results: list):
|
||||
"""Mimic Brave returning at most `count` results for the request."""
|
||||
|
||||
def _get(url, **kwargs):
|
||||
count = kwargs["params"]["count"]
|
||||
return _make_brave_response(results[:count])
|
||||
|
||||
return _get
|
||||
|
||||
|
||||
class TestGetApiKey:
|
||||
def test_returns_config_key_when_present(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "from-config"}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
from deerflow.community.brave.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "from-config"
|
||||
|
||||
def test_falls_back_to_env_when_config_key_empty(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": " "}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"BRAVE_SEARCH_API_KEY": "env-key"}, clear=True):
|
||||
from deerflow.community.brave.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
def test_falls_back_to_env_when_no_config(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"BRAVE_SEARCH_API_KEY": "env-only"}, clear=True):
|
||||
from deerflow.community.brave.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-only"
|
||||
|
||||
def test_ignores_legacy_brave_api_key(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"BRAVE_API_KEY": "legacy"}, clear=True):
|
||||
from deerflow.community.brave.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() is None
|
||||
|
||||
def test_returns_none_when_no_key_anywhere(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
from deerflow.community.brave.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() is None
|
||||
|
||||
def test_model_extra_none_does_not_crash(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = None
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
with patch.dict("os.environ", {"BRAVE_SEARCH_API_KEY": "env-key"}, clear=True):
|
||||
from deerflow.community.brave.tools import _get_api_key
|
||||
|
||||
assert _get_api_key() == "env-key"
|
||||
|
||||
|
||||
class TestWebSearchTool:
|
||||
def test_basic_search_returns_normalized_results(self, mock_config_with_key):
|
||||
results = [
|
||||
{"title": "Result 1", "url": "https://example.com/1", "description": "Desc 1"},
|
||||
{"title": "Result 2", "url": "https://example.com/2", "description": "Desc 2"},
|
||||
]
|
||||
mock_resp = _make_brave_response(results)
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.return_value = mock_resp
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "python tutorial"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["query"] == "python tutorial"
|
||||
assert parsed["total_results"] == 2
|
||||
assert parsed["results"][0]["title"] == "Result 1"
|
||||
assert parsed["results"][0]["url"] == "https://example.com/1"
|
||||
assert parsed["results"][0]["content"] == "Desc 1"
|
||||
|
||||
def test_respects_max_results_from_config(self, mock_config_with_key):
|
||||
mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
|
||||
"api_key": "test-key",
|
||||
"max_results": 3,
|
||||
}
|
||||
results = [{"title": f"R{i}", "url": f"https://x.com/{i}", "description": f"D{i}"} for i in range(10)]
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.side_effect = _count_aware_get(results)
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 3
|
||||
assert len(parsed["results"]) == 3
|
||||
|
||||
def test_max_results_parameter_accepted(self, mock_config_no_key):
|
||||
"""Tool accepts max_results as a call parameter when config does not override it."""
|
||||
results = [{"title": f"R{i}", "url": f"https://x.com/{i}", "description": f"D{i}"} for i in range(10)]
|
||||
|
||||
with patch.dict("os.environ", {"BRAVE_SEARCH_API_KEY": "env-key"}, clear=True):
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.side_effect = _count_aware_get(results)
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test", "max_results": 2})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 2
|
||||
|
||||
def test_config_max_results_overrides_parameter(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
results = [{"title": f"R{i}", "url": f"https://x.com/{i}", "description": f"D{i}"} for i in range(10)]
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.side_effect = _count_aware_get(results)
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test", "max_results": 8})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["total_results"] == 3
|
||||
|
||||
def test_max_results_string_from_env_is_coerced_and_clamped(self):
|
||||
"""Env-sourced max_results is a string and must be coerced and clamped to 20."""
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-key", "max_results": "50"}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
results = [{"title": f"R{i}", "url": f"https://x.com/{i}", "description": f"D{i}"} for i in range(30)]
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_get = mock_client_cls.return_value.__enter__.return_value.get
|
||||
mock_get.side_effect = _count_aware_get(results)
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
params = mock_get.call_args.kwargs["params"]
|
||||
|
||||
assert params["count"] == 20
|
||||
assert parsed["total_results"] == 20
|
||||
|
||||
def test_invalid_max_results_falls_back_to_default(self, caplog):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
tool_config = MagicMock()
|
||||
tool_config.model_extra = {"api_key": "test-key", "max_results": "abc"}
|
||||
mock.return_value.get_tool_config.return_value = tool_config
|
||||
|
||||
results = [{"title": f"R{i}", "url": f"https://x.com/{i}", "description": f"D{i}"} for i in range(10)]
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_get = mock_client_cls.return_value.__enter__.return_value.get
|
||||
mock_get.side_effect = _count_aware_get(results)
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
with caplog.at_level("WARNING", logger="deerflow.community.brave.tools"):
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
params = mock_get.call_args.kwargs["params"]
|
||||
|
||||
assert params["count"] == 5
|
||||
assert parsed["total_results"] == 5
|
||||
assert any("Invalid Brave Search max_results" in record.message for record in caplog.records)
|
||||
|
||||
def test_empty_results_returns_error_json(self, mock_config_with_key):
|
||||
mock_resp = _make_brave_response([])
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.return_value = mock_resp
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "no results"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["error"] == "No results found"
|
||||
assert parsed["query"] == "no results"
|
||||
|
||||
def test_missing_web_key_returns_error_json(self, mock_config_with_key):
|
||||
"""A response without a `web` block should be treated as no results."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.json.return_value = {}
|
||||
mock_resp.raise_for_status = MagicMock()
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.return_value = mock_resp
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["error"] == "No results found"
|
||||
|
||||
def test_missing_api_key_returns_error_json(self, mock_config_no_key):
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert "BRAVE_SEARCH_API_KEY" in parsed["error"]
|
||||
|
||||
def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
|
||||
import logging
|
||||
|
||||
with patch.dict("os.environ", {}, clear=True):
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.community.brave.tools"):
|
||||
web_search_tool.invoke({"query": "q1"})
|
||||
web_search_tool.invoke({"query": "q2"})
|
||||
|
||||
warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
|
||||
assert len(warnings) == 1
|
||||
|
||||
def test_http_error_returns_structured_error(self, mock_config_with_key):
|
||||
mock_error_response = MagicMock()
|
||||
mock_error_response.status_code = 403
|
||||
mock_error_response.text = "Forbidden"
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
assert "403" in parsed["error"]
|
||||
|
||||
def test_network_exception_returns_error_json(self, mock_config_with_key):
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.side_effect = Exception("timeout")
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert "error" in parsed
|
||||
|
||||
def test_sends_correct_headers_and_params(self, mock_config_with_key):
|
||||
results = [{"title": "T", "url": "https://x.com", "description": "D"}]
|
||||
mock_resp = _make_brave_response(results)
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_get = mock_client_cls.return_value.__enter__.return_value.get
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "hello world"})
|
||||
|
||||
call_kwargs = mock_get.call_args
|
||||
headers = call_kwargs.kwargs["headers"]
|
||||
params = call_kwargs.kwargs["params"]
|
||||
|
||||
assert headers["X-Subscription-Token"] == "test-brave-key"
|
||||
assert params["q"] == "hello world"
|
||||
assert params["count"] == 5
|
||||
|
||||
def test_uses_env_key_when_config_absent(self):
|
||||
with patch("deerflow.community.brave.tools.get_app_config") as mock:
|
||||
mock.return_value.get_tool_config.return_value = None
|
||||
with patch.dict("os.environ", {"BRAVE_SEARCH_API_KEY": "env-only-key"}, clear=True):
|
||||
results = [{"title": "T", "url": "https://x.com", "description": "D"}]
|
||||
mock_resp = _make_brave_response(results)
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_get = mock_client_cls.return_value.__enter__.return_value.get
|
||||
mock_get.return_value = mock_resp
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
web_search_tool.invoke({"query": "env key test"})
|
||||
headers = mock_get.call_args.kwargs["headers"]
|
||||
|
||||
assert headers["X-Subscription-Token"] == "env-only-key"
|
||||
|
||||
def test_partial_fields_in_result(self, mock_config_with_key):
|
||||
"""Missing title/url/description should default to empty string."""
|
||||
results = [{}]
|
||||
mock_resp = _make_brave_response(results)
|
||||
|
||||
with patch("deerflow.community.brave.tools.httpx.Client") as mock_client_cls:
|
||||
mock_client_cls.return_value.__enter__.return_value.get.return_value = mock_resp
|
||||
|
||||
from deerflow.community.brave.tools import web_search_tool
|
||||
|
||||
result = web_search_tool.invoke({"query": "test"})
|
||||
parsed = json.loads(result)
|
||||
|
||||
assert parsed["results"][0] == {"title": "", "url": "", "content": ""}
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Tests for Browserless community tools."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.community.browserless import tools
|
||||
from deerflow.community.browserless.browserless_client import BrowserlessClient
|
||||
|
||||
|
||||
class AsyncMock(MagicMock):
|
||||
"""Mock that supports async call."""
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBrowserlessClient:
|
||||
"""Tests for the BrowserlessClient class."""
|
||||
|
||||
async def test_fetch_html_success(self):
|
||||
"""fetch_html returns HTML content on success."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = "<html><body>Page content</body></html>"
|
||||
mock_resp.headers = {}
|
||||
mock_ctx.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000")
|
||||
result = await client.fetch_html("https://example.com")
|
||||
|
||||
assert result == "<html><body>Page content</body></html>"
|
||||
call_kwargs = mock_ctx.post.call_args.kwargs
|
||||
assert call_kwargs["json"]["url"] == "https://example.com"
|
||||
assert "waitUntil" not in call_kwargs["json"]
|
||||
assert "gotoTimeout" not in call_kwargs["json"]
|
||||
assert "bestAttempt" not in call_kwargs["json"]
|
||||
|
||||
async def test_fetch_html_empty_response(self):
|
||||
"""fetch_html returns error for empty response."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = " "
|
||||
mock_resp.headers = {}
|
||||
mock_ctx.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000")
|
||||
result = await client.fetch_html("https://example.com")
|
||||
assert result == "Error: Browserless returned empty response"
|
||||
|
||||
async def test_fetch_html_http_error(self):
|
||||
"""fetch_html returns error for non-200 status."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
mock_resp.text = "Internal error"
|
||||
mock_resp.headers = {}
|
||||
mock_ctx.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000")
|
||||
result = await client.fetch_html("https://example.com")
|
||||
assert "Error: Browserless HTTP 500" in result
|
||||
|
||||
async def test_fetch_html_timeout(self):
|
||||
"""fetch_html returns timeout error."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
import httpx
|
||||
|
||||
mock_ctx.post = AsyncMock(side_effect=httpx.TimeoutException("Timed out"))
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000", timeout_s=10)
|
||||
result = await client.fetch_html("https://example.com")
|
||||
assert "timed out" in result.lower() or "timeout" in result.lower()
|
||||
|
||||
async def test_fetch_html_with_token(self):
|
||||
"""fetch_html includes token in payload when set."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = "<html>OK</html>"
|
||||
mock_resp.headers = {}
|
||||
mock_ctx.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000", token="my-token")
|
||||
await client.fetch_html("https://example.com")
|
||||
|
||||
payload = mock_ctx.post.call_args.kwargs["json"]
|
||||
assert payload["token"] == "my-token"
|
||||
|
||||
async def test_fetch_html_with_wait_for_selector(self):
|
||||
"""fetch_html sends waitForSelector when selector is set."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = "<html>OK</html>"
|
||||
mock_resp.headers = {}
|
||||
mock_ctx.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000")
|
||||
await client.fetch_html("https://example.com", wait_for_selector="article")
|
||||
|
||||
payload = mock_ctx.post.call_args.kwargs["json"]
|
||||
assert payload["waitForSelector"]["selector"] == "article"
|
||||
|
||||
async def test_fetch_html_with_reject_params(self):
|
||||
"""fetch_html sends reject params when set."""
|
||||
with patch("deerflow.community.browserless.browserless_client.httpx.AsyncClient") as mock_cls:
|
||||
mock_ctx = MagicMock()
|
||||
mock_cls.return_value.__aenter__.return_value = mock_ctx
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.text = "<html>OK</html>"
|
||||
mock_resp.headers = {}
|
||||
mock_ctx.post = AsyncMock(return_value=mock_resp)
|
||||
|
||||
client = BrowserlessClient(base_url="http://browserless:3000")
|
||||
await client.fetch_html(
|
||||
"https://example.com",
|
||||
reject_resource_types=["image"],
|
||||
reject_request_pattern=[r"\.css$"],
|
||||
)
|
||||
|
||||
payload = mock_ctx.post.call_args.kwargs["json"]
|
||||
assert payload["rejectResourceTypes"] == ["image"]
|
||||
assert payload["rejectRequestPattern"] == [r"\.css$"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestBrowserlessTools:
|
||||
"""Tests for the Browserless tool functions."""
|
||||
|
||||
@patch("deerflow.community.browserless.tools._get_browserless_client")
|
||||
async def test_web_fetch_tool_success(self, mock_get_client):
|
||||
"""web_fetch_tool successfully fetches and extracts content."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.fetch_html = AsyncMock(return_value="<html><body><article><h1>Title</h1><p>Content</p></article></body></html>")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with patch("deerflow.community.browserless.tools._get_tool_config", return_value=None):
|
||||
result = await tools.web_fetch_tool.ainvoke("https://example.com/article")
|
||||
|
||||
assert "Error:" not in result
|
||||
|
||||
@patch("deerflow.community.browserless.tools._get_browserless_client")
|
||||
async def test_web_fetch_tool_error(self, mock_get_client):
|
||||
"""web_fetch_tool returns error when fetch fails."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.fetch_html = AsyncMock(return_value="Error: Browserless returned empty response")
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with patch("deerflow.community.browserless.tools._get_tool_config", return_value=None):
|
||||
result = await tools.web_fetch_tool.ainvoke("https://example.com")
|
||||
|
||||
assert result.startswith("Error:")
|
||||
|
||||
@patch("deerflow.community.browserless.tools._get_browserless_client")
|
||||
async def test_web_fetch_tool_exception(self, mock_get_client):
|
||||
"""web_fetch_tool returns error when client raises exception."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.fetch_html = AsyncMock(side_effect=Exception("Unexpected error"))
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
with patch("deerflow.community.browserless.tools._get_tool_config", return_value=None):
|
||||
result = await tools.web_fetch_tool.ainvoke("https://example.com")
|
||||
|
||||
assert result.startswith("Error:")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user