mirror of https://github.com/bytedance/deer-flow.git (synced 2026-06-13 10:55:59 +00:00)

Compare commits: 3 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 0fdfbae435 | |
| | 150d03f2e7 | |
| | 9593214065 | |
@@ -1,141 +0,0 @@
---
name: blocking-io-guard
description: Ensure async-path backend code that could block the asyncio event loop is protected by a teeth-verified runtime anchor in tests/blocking_io/. Use when changing backend Python under app/, packages/harness/deerflow/, or scripts/, when running a blocking-IO triage round over the whole repo, or when a reviewer/CI asks for blocking-IO coverage. Runs a deterministic scan (changed-lines or full-repo), routes each candidate, drafts/extends an anchor, and proves it fails when the blocking IO regresses.
---

# Blocking-IO Guard Skill

Help a contributor ship backend async changes together with the runtime anchor
that lets DeerFlow's blocking-IO CI gate actually see the new code. The dynamic
detector only catches blocking IO on paths a test executes — this skill closes
that gap, either for your own diff or for a repo-wide triage round.

Read `references/good-anchor-rules.md` before writing any anchor.
Only read `references/sop-skeleton.md` when generalizing this SOP to another
detector domain — it is not needed to execute the steps below.

## When to use

- Your change touches Python under `backend/app/`,
  `backend/packages/harness/deerflow/`, or `backend/scripts/` and may run on
  the async event loop (Mode A). If unsure, run Step 0 — it answers
  deterministically.
- You are doing a maintenance triage round over the existing codebase
  (Mode B).

## SOP (router)

### Step 0 — Scope (deterministic)

**Mode A — your own diff** (default, pre-PR). From repo root:

```bash
uv run --project backend python scripts/scan_changed_blocking_io.py --base origin/main
```

Lists blocking-IO candidates your change introduces: findings on lines the
diff added, **plus** findings that are new versus the merge base — the latter
catches a new async caller exposing an old sync helper whose blocking line is
not in the diff. The diff is `<base>...HEAD`, so **commit your work first** —
uncommitted lines are not selected.

If the list is empty, this change introduces no blocking-IO surface *that the
static detector can see in the changed files*. One residual blind spot
remains: reachability is same-file only, so a new async caller of a sync
helper **defined in another file** is invisible to both selections. If your
diff adds an async call into a helper that lives elsewhere, check that helper
manually (codegraph or `git grep`) before stopping.

**Mode B — full-repo triage round.** From repo root:

```bash
make detect-blocking-io
```

Prints a summary and writes the complete structured finding list to
`.deer-flow/blocking-io-findings.json`. Work HIGH priority first; do not start
MEDIUM until every HIGH is dispositioned (fixed, guarded, or recorded
NO-ACTION).

**Batching policy (PR sizing).** One **fix unit** per PR while any HIGH
remains: a fix unit is one root cause — usually a single HIGH, but two HIGHs
resolved by the same one-place fix belong together. Once no HIGH remains,
MEDIUM/LOW may be batched (about five per round, grouped by module or by
disposition) so each PR stays reviewable. A new Blockbuster rule is never
batched with anything — it always ships alone (see Step 5).

Both modes emit the same JSON shape per finding: `priority`, `location`
(path/line/function), `blocking_call` (category/operation/symbol),
`event_loop_exposure`, `reason`, `code`. Priority is a deterministic review
ordering, not proof of a bug — Step 1 makes the actual call.
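The ordering and batching rules above can be sketched in a few lines of Python. This is only an illustration: the nesting of `location` and `blocking_call`, the `operation` value, the file paths, and the helper names `triage_order` and `next_round` are assumptions, not the scanner's documented output or API. Taking a single HIGH per round is a simplification of "one fix unit", since two HIGHs with the same root cause still have to be grouped by a human.

```python
# Hypothetical sketch: order findings for review and pick the next round.
# Field nesting and helper names are assumptions made for illustration.
PRIORITY_RANK = {"HIGH": 0, "MEDIUM": 1, "LOW": 2}


def make_finding(priority, path, line, function, symbol):
    """Build a finding in the assumed JSON shape (subset of fields)."""
    return {
        "priority": priority,
        "location": {"path": path, "line": line, "function": function},
        "blocking_call": {"category": "FILE_IO", "operation": "call", "symbol": symbol},
    }


def triage_order(findings):
    """Sort HIGH first, then MEDIUM, then LOW; ties broken by path and line."""
    def sort_key(finding):
        location = finding["location"]
        rank = PRIORITY_RANK.get(finding["priority"], len(PRIORITY_RANK))
        return (rank, location["path"], location["line"])

    return sorted(findings, key=sort_key)


def next_round(findings, batch_size=5):
    """While any HIGH remains, return one HIGH; otherwise a small batch."""
    ordered = triage_order(findings)
    highs = [f for f in ordered if f["priority"] == "HIGH"]
    if highs:
        return highs[:1]
    return ordered[:batch_size]


SAMPLE_FINDINGS = [
    make_finding("MEDIUM", "app/b.py", 10, "load", "open"),
    make_finding("HIGH", "app/z.py", 5, "cleanup", "shutil.rmtree"),
    make_finding("HIGH", "app/a.py", 30, "save", "Path.write_text"),
    make_finding("LOW", "app/c.py", 7, "probe", "os.stat"),
]
```

Calling `next_round(SAMPLE_FINDINGS)` yields only the first HIGH; MEDIUM and LOW entries are returned only once no HIGH is left in the input.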
### Step 1 — Judge each candidate (router)

Read the code around each candidate and route it:

- **Already offloaded** (`asyncio.to_thread`, `run_in_executor`, async client) →
  **GUARD**: add/extend an anchor that locks the offload so a future edit cannot
  move it back onto the loop.
- **On the loop, not offloaded** → **FIX+ANCHOR**: offload the production code
  (your fix), then add an anchor that guards it.
- **Not actually exposed / acceptable** (rare: scanner false positive,
  startup-only code) → **NO-ACTION**: record one line of why.
- **Cross-file caveat**: the scanner's async reachability is same-file only
  (`ASYNC_REACHABLE_SAME_FILE`). If the candidate is a *sync helper*, check for
  async callers in other files (codegraph or `git grep`) before deciding
  NO-ACTION.
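To make the first two routes concrete, here is a minimal sketch of the same cleanup written both ways. The function names are hypothetical and not taken from DeerFlow; the thread-identity comparison is only a teaching device to show where the call runs, and it is not how the Blockbuster gate detects blocking IO.

```python
import asyncio
import shutil
import tempfile
import threading
from pathlib import Path


async def cleanup_on_loop(path: Path) -> int:
    """FIX+ANCHOR shape: rmtree runs directly on the event-loop thread."""
    shutil.rmtree(path)
    return threading.get_ident()


async def cleanup_offloaded(path: Path) -> int:
    """GUARD shape: rmtree runs in a worker thread via asyncio.to_thread."""

    def _remove() -> int:
        shutil.rmtree(path)
        return threading.get_ident()

    return await asyncio.to_thread(_remove)


def ran_on_loop_thread(cleanup) -> bool:
    """Return True when the blocking call executed on the event-loop thread."""
    target = Path(tempfile.mkdtemp())  # real local directory to remove

    async def main() -> bool:
        loop_thread = threading.get_ident()
        return await cleanup(target) == loop_thread

    return asyncio.run(main())
```

`ran_on_loop_thread(cleanup_on_loop)` is true and `ran_on_loop_thread(cleanup_offloaded)` is false, which is the difference between a candidate that needs a fix and one that only needs a guard.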
### Step 2 — Apply the fix, then re-scan (FIX+ANCHOR only)

Offload the blocking call in production code, then re-run the Step 0 scan and
confirm the candidate no longer appears. If the offloaded call sits in a
`finally` / cleanup path, keep it best-effort and bounded (swallow-and-log,
`asyncio.wait_for`) so a failing or hung cleanup cannot mask the primary
exception. Match by the stable key
**(path, function, symbol)** — line numbers shift after edits, so never
compare by line.

- The finding must disappear. If it still shows, the fix did not remove the
  blocking pattern (e.g. the call is still a direct call, not offloaded) —
  go back before touching any test.
- GUARD / NO-ACTION routes skip this step: a residual finding there is
  *expected* (the raw call still exists inside a sync helper with the offload
  at the caller, or the exposure was judged acceptable).
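The stable-key comparison described above might look like the following sketch. The nested field layout, the sample paths, and the helper names `stable_key` and `fix_removed_finding` are assumptions for illustration; only the key itself, (path, function, symbol), comes from the text.

```python
# Hypothetical sketch: compare findings before and after a fix by stable key.
def stable_key(finding):
    """Identify a finding by (path, function, symbol), ignoring line numbers."""
    return (
        finding["location"]["path"],
        finding["location"]["function"],
        finding["blocking_call"]["symbol"],
    )


def fix_removed_finding(fixed, rescan):
    """Return True when the fixed candidate no longer appears in the re-scan."""
    return stable_key(fixed) not in {stable_key(f) for f in rescan}


FIXED = {
    "location": {"path": "app/uploads.py", "line": 42, "function": "save_upload"},
    "blocking_call": {"symbol": "shutil.rmtree"},
}

# After the edit, an unrelated finding has shifted onto line 42.
RESCAN_CLEAN = [
    {
        "location": {"path": "app/uploads.py", "line": 42, "function": "list_uploads"},
        "blocking_call": {"symbol": "os.listdir"},
    },
]

# Here the original call is still present, merely moved to another line.
RESCAN_STILL_BLOCKING = [
    {
        "location": {"path": "app/uploads.py", "line": 47, "function": "save_upload"},
        "blocking_call": {"symbol": "shutil.rmtree"},
    },
]
```

A line-based comparison would get both cases wrong: it would report the fixed finding as still present in `RESCAN_CLEAN` and as gone in `RESCAN_STILL_BLOCKING`.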
This is pattern-level feedback in seconds; it complements but never replaces
Step 5 — only the runtime gate proves the event loop is actually protected.
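For the cleanup-path advice in this step, a best-effort, bounded cleanup could be sketched as below. `create_with_cleanup`, its parameters, and the 5-second default are hypothetical; the sketch only assumes the standard `asyncio.to_thread` and `asyncio.wait_for` calls.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def create_with_cleanup(create, cleanup, timeout: float = 5.0):
    """Run `create`; on failure run the sync `cleanup` off the loop, then re-raise.

    The cleanup is offloaded with asyncio.to_thread, bounded by wait_for, and
    any cleanup error or timeout is logged and swallowed so the primary
    exception from `create` is the one the caller sees.
    """
    try:
        return await create()
    except Exception:
        try:
            await asyncio.wait_for(asyncio.to_thread(cleanup), timeout=timeout)
        except Exception:
            logger.warning("cleanup failed or timed out", exc_info=True)
        raise
```

One limit of this approach is worth stating: when `wait_for` times out, the worker thread is not interrupted. The coroutine stops waiting for it, but the blocking call keeps running in the thread pool until it returns.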
### Step 3 — Check existing anchors

Look in `backend/tests/blocking_io/` for a test that drives the production async
entry point reaching this candidate's branch.

- Covers this branch already → go to Step 5 (re-verify teeth).
- Covers the entry point but not this branch (e.g. happy path covered,
  cleanup/404/409 not) → **extend** that anchor.
- None → create one from `templates/anchor.template.py`.

### Step 4 — Generate / extend the anchor

Follow `references/good-anchor-rules.md`. Drive the *specific* branch (e.g. force
the create failure that hits the cleanup `shutil.rmtree`). Never bypass the
blocking surface with a test-only `asyncio.to_thread` wrapper.

### Step 5 — Verify teeth (mandatory; also the anchor-vs-rule discriminator)

1. Reintroduce the block (GUARD: temporarily revert the offload; FIX+ANCHOR: run
   against the pre-fix code).
2. Run `cd backend && make test-blocking-io` (or target the one test). It **must
   go RED**.
3. Restore the fix. It **must go GREEN**.

A real block that stays GREEN means Blockbuster has no rule for that
primitive — that is the **RULE** route; see `references/good-anchor-rules.md`
for the admission criteria before adding one.
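The red/green cycle can be illustrated with a toy gate. This is a deliberately simplified stand-in written for this example, not Blockbuster and not the project's conftest: it patches only `time.sleep` and raises when the call happens on a thread that has a running event loop. All names here (`toy_gate`, `ToyBlockingError`, `gate_result`) are hypothetical.

```python
import asyncio
import time
from contextlib import contextmanager


class ToyBlockingError(RuntimeError):
    """Raised by the toy gate when time.sleep runs on the event loop."""


@contextmanager
def toy_gate():
    """Patch time.sleep so it fails when called from a running event loop."""
    real_sleep = time.sleep

    def guarded_sleep(seconds):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return real_sleep(seconds)  # worker thread: no loop, allowed
        raise ToyBlockingError("time.sleep called on the event loop")

    time.sleep = guarded_sleep
    try:
        yield
    finally:
        time.sleep = real_sleep


async def entry_point_with_block():
    """Step 5.1: the block reintroduced (offload reverted)."""
    time.sleep(0.01)


async def entry_point_with_fix():
    """Step 5.3: the fix restored (call offloaded to a thread)."""
    await asyncio.to_thread(time.sleep, 0.01)


def gate_result(entry_point) -> str:
    """Drive the entry point under the toy gate and report RED or GREEN."""
    with toy_gate():
        try:
            asyncio.run(entry_point())
        except ToyBlockingError:
            return "RED"
    return "GREEN"
```

The toy also shows the RULE route in miniature: a blocking primitive the gate does not patch (anything other than `time.sleep` here) would stay GREEN even when it genuinely blocks the loop.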
### Step 6 — Deliver

Commit the anchor(s) with your change; `make test-blocking-io` green. In the PR,
note: candidates found, each disposition, the re-scan result (Step 2), and
the teeth evidence (red→green). Include the reason for any NO-ACTION. A new
Blockbuster rule, if any, goes in its own commit with the evidence from Step 5.
@@ -1,65 +0,0 @@
# Good anchor rules + teeth (blocking-IO fill)

Distilled from `backend/docs/BLOCKING_IO_DETECTION.md`. An anchor lives in
`backend/tests/blocking_io/`; the suite's conftest runs each test under the
strict Blockbuster gate scoped to `app.*` / `deerflow.*`.

The examples in this file and in `templates/` are all filesystem-flavored.
They demonstrate how to *write* the test, not what the SOP covers: the same
rules apply to every category the detector reports (FILE_IO, HTTP,
SUBPROCESS, SLEEP), and the acceptance criterion is always the teeth check
below — never similarity to an example.

## A good anchor

- Calls the **real production async entry point** — not a low-level helper,
  unless that helper *is* the entry point production executes.
- Does **not** bypass the blocking surface with a test-only
  `asyncio.to_thread` / `run_in_executor` wrapper.
- Uses **real local filesystem** inputs when the bug shape is filesystem IO.
- Mocks **only** the external dependency boundary (network service, third-party
  saver), never the offload being guarded.
- Drives the **specific branch** you are protecting (error / cleanup / 404 /
  409), not just the happy path.

## Teeth (the acceptance test)

An anchor only counts if the gate actually fires when the code blocks:

1. Reintroduce the block (revert the offload, or run pre-fix code).
2. `cd backend && make test-blocking-io` → the anchor **must fail** (RED).
3. Restore the fix → the anchor **must pass** (GREEN).

A green-on-happy-path anchor with no proven red is fake coverage. Don't ship it.

## The RULE route (rare; strict admission criteria)

Blockbuster's built-in rules cover the common blocking primitives well. The
two deliberate openings in this SOP are:

1. **Coverage opening** (the normal case): the rules already see the
   primitive — you only need an anchor so runtime detection executes the real
   business path and CI prevents regression.
2. **Rule opening** (rare): you reintroduced a *real* block and the gate
   stayed GREEN — Blockbuster has no rule for that primitive.

A project rule lives in `_PROJECT_BLOCKING_RULES` inside
`backend/tests/support/detectors/blocking_io_runtime.py` and changes detection
for the **entire** blocking-IO suite — global blast radius. Admission criteria
for adding one:

- You have the **fails-to-fail anchor** as evidence: a good anchor (per the
  rules above) that drives a genuinely blocking path and stays green. No
  evidence, no rule.
- The primitive is a real blocking call (verified against its implementation
  or docs), not a false positive of the static detector.
- The rule ships in its **own commit**, naming the primitive, the anchor that
  exposed the gap, and the suite-wide impact. Run the full
  `make test-blocking-io` suite after adding it — a new rule can turn other
  previously-green tests red, and each such red is either a real latent bug
  (fix it) or rule overreach (narrow the rule).
- If you are not in a position to own that blast radius (e.g. external
  contributor), escalate to a maintainer with the evidence instead.

**Never add a runtime rule just because a path is untested** — that case needs
an anchor, not a rule.
@@ -1,34 +0,0 @@
# SOP skeleton (generic shape — extraction seam)

This is the domain-agnostic shape the blocking-IO skill instantiates. It exists
so a second detector/gate domain can reuse the flow without copying it. Do not
add machinery for that until a second domain actually appears (YAGNI).

A domain provides:
- a **static detector** that can scan a diff (or the whole tree) and emit
  located candidates,
- a **CI gate** that fails when the bad pattern executes,
- a **test location** for guard tests,
- **good-test rules** for that gate,
- a **teeth definition** (how to make the gate fire on purpose).

Steps:
1. **Scope (deterministic):** intersect the diff's added lines with the
   detector's findings → candidates this change introduced/touched. (Or, in
   triage mode, take the full finding list ordered by priority.)
2. **Judge (router):** per candidate — guard existing fix / fix + guard /
   no-action / rule (the gate cannot see the primitive).
3. **Fix + re-scope (fixes only):** apply the fix, re-run the detector; the
   fixed candidate must vanish from the findings (match by a stable key, not
   line numbers). Pattern-level feedback in seconds — complements, never
   replaces, step 5.
4. **Generate:** draft or extend a guard test per the good-test rules, driving
   the specific branch.
5. **Verify teeth:** make the bad pattern happen → gate must fail; restore →
   gate must pass. A pattern that stays green while genuinely bad is the
   "rule" signal, not a coverage success.
6. **Deliver:** commit the verified guard test; any gate-rule change ships in
   its own commit with the fails-to-fail evidence attached.

To add a domain: supply a new fill doc (like `good-anchor-rules.md`) + detector,
and promote this file into a parent skill the instances point at.
@@ -1,32 +0,0 @@
"""Template: a tests/blocking_io/ runtime anchor.

Copy into backend/tests/blocking_io/test_<area>.py and adapt. The suite's
conftest already wraps every test here in the strict Blockbuster gate, so you do
NOT import or activate the detector — just drive the real async entry point.

Teeth check before you commit (see references/good-anchor-rules.md):
  1. reintroduce the block -> `cd backend && make test-blocking-io` must FAIL
  2. restore the fix       -> it must PASS
"""

from __future__ import annotations

from pathlib import Path

import pytest

# from app.<module> import <real_async_entry_point>

pytestmark = pytest.mark.asyncio


async def test_<entry_point>_offloads_blocking_io_on_<branch>(tmp_path: Path) -> None:
    # Arrange: real inputs at the boundary the code blocks on (FS -> tmp_path;
    # HTTP/subprocess -> stub the external service). Mock ONLY the external
    # boundary, never the offload under test.

    # Act + Assert: call the REAL production async entry point and drive the
    # specific branch you are guarding (e.g. force a failure to hit the cleanup
    # path). If the entry point performs blocking IO on the loop, the gate fails.
    # await <real_async_entry_point>(...)
    raise NotImplementedError("Replace with the real async entry point call.")
@@ -21,7 +21,6 @@ INFOQUEST_API_KEY=your-infoquest-api-key
 # DEEPSEEK_API_KEY=your-deepseek-api-key
 # NOVITA_API_KEY=your-novita-api-key # OpenAI-compatible, see https://novita.ai
 # MINIMAX_API_KEY=your-minimax-api-key # OpenAI-compatible, see https://platform.minimax.io
-# STEPFUN_API_KEY=your-stepfun-api-key # OpenAI-compatible, see https://platform.stepfun.com
 # VLLM_API_KEY=your-vllm-api-key # OpenAI-compatible
 # FEISHU_APP_ID=your-feishu-app-id
 # FEISHU_APP_SECRET=your-feishu-app-secret
@@ -0,0 +1,72 @@
# Path-based PR auto-labeling config for actions/labeler@v5.
# Each key is a label (must exist — see .github/labels.yml); the globs decide
# when it is applied. A PR can match several areas, which is expected.

"area:frontend":
  - changed-files:
      - any-glob-to-any-file:
          - "frontend/**"

"area:backend":
  - changed-files:
      - any-glob-to-any-file:
          - "backend/app/**"
          - "backend/packages/harness/deerflow/runtime/**"
          - "backend/packages/harness/deerflow/persistence/**"
          - "backend/packages/harness/deerflow/config/**"
          - "backend/packages/harness/deerflow/tools/**"
          - "backend/packages/harness/deerflow/guardrails/**"
          - "backend/packages/harness/deerflow/tracing/**"
          - "backend/packages/harness/deerflow/models/**"
          - "backend/packages/harness/deerflow/utils/**"
          - "backend/packages/harness/deerflow/uploads/**"

"area:agents":
  - changed-files:
      - any-glob-to-any-file:
          - "backend/packages/harness/deerflow/agents/**"
          - "backend/packages/harness/deerflow/subagents/**"
          - "backend/packages/harness/deerflow/reflection/**"
          - "backend/langgraph.json"
          - "backend/**/prompts/**"

"area:sandbox":
  - changed-files:
      - any-glob-to-any-file:
          - "docker/**"
          - "backend/packages/harness/deerflow/sandbox/**"
          - "backend/Dockerfile"
          - "frontend/Dockerfile"

"area:skills":
  - changed-files:
      - any-glob-to-any-file:
          - "skills/**"
          - "backend/packages/harness/deerflow/skills/**"
          - "frontend/src/core/skills/**"

"area:mcp":
  - changed-files:
      - any-glob-to-any-file:
          - "backend/packages/harness/deerflow/mcp/**"
          - "frontend/src/core/mcp/**"

"area:ci":
  - changed-files:
      - any-glob-to-any-file:
          - ".github/**"
          - "scripts/**"

"area:docs":
  - changed-files:
      - any-glob-to-any-file:
          - "docs/**"
          - "**/*.md"

"area:deps":
  - changed-files:
      - any-glob-to-any-file:
          - "backend/pyproject.toml"
          - "backend/uv.lock"
          - "frontend/package.json"
          - "frontend/pnpm-lock.yaml"
@@ -0,0 +1,44 @@
name: Issue Triage

# Ensures every newly opened issue carries `needs-triage`, even blank or
# API-created ones that bypass the issue templates. Creates the label if it is
# somehow missing, so the workflow is self-healing.

on:
  issues:
    types: [opened]

permissions:
  issues: write

jobs:
  needs-triage:
    runs-on: ubuntu-latest
    steps:
      - name: Add needs-triage label
        uses: actions/github-script@v7
        with:
          script: |
            const { owner, repo } = context.repo;
            const issue_number = context.payload.issue.number;

            const current = (context.payload.issue.labels || []).map(l => l.name);
            if (current.includes('needs-triage')) {
              core.info('Issue already has needs-triage; nothing to do.');
              return;
            }

            // Self-heal: create the label if it does not exist yet.
            try {
              await github.rest.issues.createLabel({
                owner, repo, name: 'needs-triage', color: 'fef2c0',
                description: 'Awaiting maintainer triage',
              });
            } catch (e) {
              if (e.status !== 422) throw e; // 422 = already exists
            }

            await github.rest.issues.addLabels({
              owner, repo, issue_number, labels: ['needs-triage'],
            });
            core.info(`Added needs-triage to #${issue_number}.`);
@@ -10,7 +10,7 @@ permissions:
   contents: read

 jobs:
-  lint-backend:
+  lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v6
@@ -0,0 +1,28 @@
name: PR Labeler

# Applies area:* labels based on which files a PR changes (see .github/labeler.yml).
# Uses pull_request_target so it also works on fork PRs. SAFE: actions/labeler
# only reads the changed-file list via the API — it never checks out or runs PR code.

on:
  pull_request_target:
    types: [opened, synchronize, reopened, ready_for_review]

permissions:
  contents: read
  pull-requests: write

concurrency:
  group: pr-labeler-${{ github.event.pull_request.number }}
  cancel-in-progress: true

jobs:
  label:
    if: github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    steps:
      - name: Apply area labels
        uses: actions/labeler@v5
        with:
          configuration-path: .github/labeler.yml
          sync-labels: true
@@ -0,0 +1,164 @@
name: PR Triage

# Two responsibilities, both pure-metadata (no PR code is checked out or run):
# 1. On open/sync: apply size/* + risk:* labels, and needs-validation when the
#    PR touches the front/back contract surface (backend API, SSE, agents, or
#    the frontend streaming client). A `skip-validation` label opts out.
# 2. On maintainer review: apply the `reviewing` label.
#
# All labels are managed within their own namespace — labels outside size/*,
# risk:*, needs-validation and reviewing are never touched here.

on:
  pull_request_target:
    types: [opened, synchronize, reopened, ready_for_review]
  pull_request_review:
    types: [submitted]

permissions:
  contents: read
  pull-requests: write

concurrency:
  group: pr-triage-${{ github.event.pull_request.number }}
  cancel-in-progress: false

jobs:
  size-and-risk:
    if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    steps:
      - name: Label size, risk and validation need
        uses: actions/github-script@v7
        with:
          script: |
            const pr = context.payload.pull_request;
            const { owner, repo } = context.repo;
            const prNumber = pr.number;

            // ---- size, from additions + deletions ----
            const churn = (pr.additions || 0) + (pr.deletions || 0);
            const sizeLabel =
              churn < 20 ? 'size/XS' :
              churn < 100 ? 'size/S' :
              churn < 300 ? 'size/M' :
              churn < 700 ? 'size/L' : 'size/XL';

            // ---- changed paths ----
            const files = await github.paginate(github.rest.pulls.listFiles, {
              owner, repo, pull_number: prNumber, per_page: 100,
            });
            const paths = files.map(f => f.filename);

            const matches = (re) => paths.some(p => re.test(p));

            const docsOnly = paths.length > 0 && paths.every(p =>
              /\.(md|mdx|txt)$/i.test(p) || p.startsWith('docs/') ||
              /\.(png|jpe?g|gif|svg|webp|ico)$/i.test(p));

            const highRisk = matches(
              /^backend\/app\/gateway\//) || matches(
              /^backend\/packages\/harness\/deerflow\/(agents|subagents|sandbox)\//) || matches(
              /(^|\/)langgraph\.json$/) || matches(
              /(^|\/)(auth|authz|security)/i) || matches(
              /(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/) || matches(
              /^docker\//) || matches(
              /^\.github\/workflows\//);

            const riskLabel = docsOnly ? 'risk:low' : (highRisk ? 'risk:high' : 'risk:medium');

            // needs-validation: front/back contract surface
            const contractSurface =
              matches(/^backend\/app\/gateway\//) ||
              matches(/^backend\/packages\/harness\/deerflow\/(agents|subagents)\//) ||
              matches(/(^|\/)langgraph\.json$/) ||
              matches(/^frontend\/src\/core\/(api|threads|messages)\//);

            const current = (pr.labels || []).map(l => l.name);
            const hasSkip = current.includes('skip-validation');

            const desired = [sizeLabel, riskLabel];
            if (contractSurface && !hasSkip) desired.push('needs-validation');

            const managed = (name) =>
              name.startsWith('size/') || name.startsWith('risk:') || name === 'needs-validation';

            const toRemove = current.filter(l => managed(l) && !desired.includes(l));
            const toAdd = desired.filter(l => !current.includes(l));

            for (const name of toRemove) {
              try {
                await github.rest.issues.removeLabel({ owner, repo, issue_number: prNumber, name });
              } catch (e) {
                if (e.status !== 404) throw e;
              }
            }
            if (toAdd.length) {
              await github.rest.issues.addLabels({ owner, repo, issue_number: prNumber, labels: toAdd });
            }
            core.info(`size=${sizeLabel} risk=${riskLabel} churn=${churn} ` +
              `validation=${desired.includes('needs-validation')} ` +
              `(+${toAdd.join(',') || '-'} / -${toRemove.join(',') || '-'})`);

  first-time:
    if: github.event_name == 'pull_request_target' && github.event.action == 'opened'
    runs-on: ubuntu-latest
    steps:
      - name: Label first-time contributors
        uses: actions/github-script@v7
        with:
          script: |
            const pr = context.payload.pull_request;
            const { owner, repo } = context.repo;
            const assoc = pr.author_association;
            const isBot = pr.user.type === 'Bot';
            core.info(`author=${pr.user.login} association=${assoc} bot=${isBot}`);

            // FIRST_TIME_CONTRIBUTOR = no prior merged commit to this repo;
            // FIRST_TIMER = no prior commit anywhere on GitHub. Either counts.
            if (isBot || !['FIRST_TIME_CONTRIBUTOR', 'FIRST_TIMER'].includes(assoc)) {
              core.info('Not a first-time contributor; skipping.');
              return;
            }
            await github.rest.issues.addLabels({
              owner, repo, issue_number: pr.number, labels: ['first-time-contributor'],
            });
            core.info(`Added first-time-contributor to #${pr.number}.`);

  reviewing:
    if: github.event_name == 'pull_request_review'
    runs-on: ubuntu-latest
    steps:
      - name: Add reviewing label for maintainer reviews
        uses: actions/github-script@v7
        with:
          script: |
            const { owner, repo } = context.repo;
            const prNumber = context.payload.pull_request.number;
            const reviewer = context.payload.review.user.login;

            const { data: perm } = await github.rest.repos.getCollaboratorPermissionLevel({
              owner, repo, username: reviewer,
            });
            if (!['admin', 'write', 'maintain'].includes(perm.permission)) {
              core.info(`Reviewer ${reviewer} (${perm.permission}) is not a maintainer; skipping.`);
              return;
            }

            const { data: labels } = await github.rest.issues.listLabelsOnIssue({
              owner, repo, issue_number: prNumber,
            });
            if (labels.some(l => l.name === 'reviewing')) {
              core.info('Already labeled reviewing; skipping.');
              return;
            }
            try {
              await github.rest.issues.addLabels({
                owner, repo, issue_number: prNumber, labels: ['reviewing'],
              });
              core.info(`Added "reviewing" (reviewer ${reviewer}).`);
            } catch (e) {
              // 403 is expected for review events on some fork PR contexts.
              if (e.status === 403) core.info('No permission to label (expected on some fork PRs).');
              else throw e;
            }
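The size and risk rules in the `size-and-risk` job are easier to check in isolation when restated as pure functions. The Python port below is a sketch for reasoning about the thresholds and precedence, not part of the workflow; the function names are made up here, and the regular expressions are transcribed from the script above.

```python
import re

# Upper bounds are exclusive, mirroring the `churn < N` chain in the script.
SIZE_THRESHOLDS = [(20, "size/XS"), (100, "size/S"), (300, "size/M"), (700, "size/L")]

DOC_PATTERNS = [
    re.compile(r"\.(md|mdx|txt)$", re.IGNORECASE),
    re.compile(r"\.(png|jpe?g|gif|svg|webp|ico)$", re.IGNORECASE),
]

HIGH_RISK_PATTERNS = [
    re.compile(r"^backend/app/gateway/"),
    re.compile(r"^backend/packages/harness/deerflow/(agents|subagents|sandbox)/"),
    re.compile(r"(^|/)langgraph\.json$"),
    re.compile(r"(^|/)(auth|authz|security)", re.IGNORECASE),
    re.compile(r"(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$"),
    re.compile(r"^docker/"),
    re.compile(r"^\.github/workflows/"),
]


def size_label(additions: int, deletions: int) -> str:
    """Map total churn (additions + deletions) to a size/* label."""
    churn = additions + deletions
    for limit, label in SIZE_THRESHOLDS:
        if churn < limit:
            return label
    return "size/XL"


def is_doc_path(path: str) -> bool:
    """True for docs/ files and for doc or image extensions."""
    return path.startswith("docs/") or any(rx.search(path) for rx in DOC_PATTERNS)


def risk_label(paths: list[str]) -> str:
    """Docs-only wins first; otherwise any high-risk path makes the PR high."""
    docs_only = bool(paths) and all(is_doc_path(p) for p in paths)
    if docs_only:
        return "risk:low"
    high = any(rx.search(p) for p in paths for rx in HIGH_RISK_PATTERNS)
    return "risk:high" if high else "risk:medium"
```

One consequence of the precedence is visible here: a PR that only touches a Markdown file under `.github/workflows/` is labeled `risk:low`, because the docs-only check runs before the high-risk patterns are consulted.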
@@ -1,108 +0,0 @@
name: Replay E2E (front-back contract)

# Guards the front-back contract via record/replay (no API key in CI):
#   Layer 1 — backend golden: replay a recorded trace through the real gateway,
#             assert the SSE event sequence matches the committed golden.
#   Layer 2 — full-stack render: real Next.js frontend + real gateway (replay
#             model) + Chromium; assert the replayed turns render in the browser.
# Triggered by changes on EITHER side of the contract so a backend change can no
# longer pass without the frontend-facing checks running.

on:
  push:
    branches: ["main"]
    paths:
      - "frontend/**"
      - "backend/app/gateway/**"
      - "backend/packages/harness/**"
      - "backend/tests/fixtures/replay/**"
      - "backend/tests/replay_provider.py"
      - "backend/tests/_replay_fixture.py"
      - "backend/tests/seed_runs_router.py"
      - "backend/tests/test_replay_golden.py"
      - "backend/scripts/run_replay_gateway.py"
      - ".github/workflows/replay-e2e.yml"
  pull_request:
    types: [opened, synchronize, reopened, ready_for_review]
    paths:
      - "frontend/**"
      - "backend/app/gateway/**"
      - "backend/packages/harness/**"
      - "backend/tests/fixtures/replay/**"
      - "backend/tests/replay_provider.py"
      - "backend/tests/_replay_fixture.py"
      - "backend/tests/seed_runs_router.py"
      - "backend/tests/test_replay_golden.py"
      - "backend/scripts/run_replay_gateway.py"
      - ".github/workflows/replay-e2e.yml"

concurrency:
  group: replay-e2e-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
  backend-replay-golden:
    name: Layer 1 — backend golden (no API key)
    if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    timeout-minutes: 15
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.12"
      - name: Install uv
        uses: astral-sh/setup-uv@v7
      - name: Install backend dependencies
        working-directory: backend
        run: uv sync --group dev
      - name: Replay golden (backend SSE contract)
        working-directory: backend
        run: PYTHONPATH=. uv run pytest tests/test_replay_golden.py -v

  fullstack-replay-render:
    name: Layer 2 — full-stack render (no API key)
    if: github.event_name != 'pull_request' || github.event.pull_request.draft == false
    runs-on: ubuntu-latest
    timeout-minutes: 25
    steps:
      - uses: actions/checkout@v6
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.12"
      - name: Install uv
        uses: astral-sh/setup-uv@v7
      - name: Install backend dependencies (replay gateway)
        working-directory: backend
        run: uv sync --group dev
      - name: Setup Node.js
        uses: actions/setup-node@v4
        with:
          node-version: "22"
      - name: Enable Corepack
        run: corepack enable
      - name: Use pinned pnpm version
        run: corepack prepare pnpm@10.26.2 --activate
      - name: Install frontend dependencies
        working-directory: frontend
        run: pnpm install --frozen-lockfile
      - name: Install Playwright Chromium
        working-directory: frontend
        run: npx playwright install chromium --with-deps
      - name: Full-stack replay render (DOM assertions are the gate)
||||||
working-directory: frontend
|
|
||||||
run: pnpm exec playwright test -c playwright.real-backend.config.ts
|
|
||||||
- name: Upload report + render artifact
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
if: ${{ !cancelled() }}
|
|
||||||
with:
|
|
||||||
name: replay-render
|
|
||||||
path: |
|
|
||||||
frontend/playwright-report/
|
|
||||||
frontend/test-results/
|
|
||||||
retention-days: 7
|
|
||||||
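The Layer 1 check described in the header comment (replay a recorded trace, then compare the SSE event sequence against a committed golden) can be sketched roughly as below. This is an illustration only, not the code in `tests/test_replay_golden.py`: the helpers `parseSse` and `firstMismatch` and the event names `metadata`, `values`, and `end` are assumptions invented for the example.

```javascript
// Hypothetical helper: split a raw SSE body into { event, data } records.
function parseSse(raw) {
  return raw
    .split(/\r?\n\r?\n/)                 // a blank line terminates each event
    .map(block => block.trim())
    .filter(block => block.length > 0)
    .map(block => {
      let event = 'message';             // default event type when no "event:" field is present
      const data = [];
      for (const line of block.split(/\r?\n/)) {
        if (line.startsWith('event:')) event = line.slice(6).trim();
        else if (line.startsWith('data:')) data.push(line.slice(5).trim());
      }
      return { event, data: data.join('\n') };
    });
}

// Index of the first position where the event types diverge, or -1 if the sequences agree.
function firstMismatch(actual, golden) {
  const n = Math.max(actual.length, golden.length);
  for (let i = 0; i < n; i++) {
    if (!actual[i] || !golden[i] || actual[i].event !== golden[i].event) return i;
  }
  return -1;
}

// Assumed sample stream and golden; a real golden would be loaded from a committed fixture.
const replayed = parseSse(
  'event: metadata\ndata: {"run_id":"r1"}\n\n' +
  'event: values\ndata: {}\n\n' +
  'event: end\ndata: null\n\n'
);
const golden = [{ event: 'metadata' }, { event: 'values' }, { event: 'end' }];
console.log(firstMismatch(replayed, golden));
```

Comparing only the event types keeps the sketch short; a stricter gate would presumably also compare the payloads after normalizing volatile fields such as ids and timestamps.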
@@ -1,223 +0,0 @@
-name: Triage
-
-# One workflow for all event-driven PR/issue labeling. Replaces the former
-# pr-labeler / pr-triage / issue-triage workflows (and drops actions/labeler).
-#
-# Design notes:
-# * All jobs are pure-metadata: they read changed-file lists / PR fields / the
-#   review payload via the API and write labels. PR code is NEVER checked out
-#   or executed, so pull_request_target is safe here.
-# * Each job only reconciles labels in namespaces IT owns
-#   (area:* / size/* / risk:* / needs-validation). It never touches labels
-#   applied by maintainers or other tools (bug, priority, etc.). first-time-
-#   contributor and reviewing are add-only.
-# * State is read LIVE (listFiles + listLabelsOnIssue) at run time, not from
-#   the (stale) event payload, so rapid synchronize events converge instead
-#   of thrashing.
-
-on:
-  pull_request_target:
-    types: [opened, synchronize, reopened, ready_for_review]
-  pull_request_review:
-    types: [submitted]
-  issues:
-    types: [opened]
-
-permissions:
-  contents: read
-  pull-requests: write
-  issues: write
-
-jobs:
-  # ── PR: area / size / risk / needs-validation / first-time ─────────────────
-  pr-labels:
-    if: github.event_name == 'pull_request_target' && github.event.pull_request.draft == false
-    runs-on: ubuntu-latest
-    concurrency:
-      group: triage-pr-${{ github.event.pull_request.number }}
-      cancel-in-progress: true
-    steps:
-      - name: Apply PR labels from live state
-        uses: actions/github-script@v8
-        with:
-          script: |
-            const pr = context.payload.pull_request;
-            const { owner, repo } = context.repo;
-            const num = pr.number;
-
-            // ---- live changed files ----
-            const files = await github.paginate(github.rest.pulls.listFiles, {
-              owner, repo, pull_number: num, per_page: 100,
-            });
-            const paths = files.map(f => f.filename);
-            const m = (re) => paths.some(p => re.test(p));
-
-            // ---- area: replaces .github/labeler.yml (path -> area) ----
-            const AREA_RULES = [
-              ['area:frontend', [/^frontend\//]],
-              ['area:backend', [/^backend\/app\//, /^backend\/packages\/harness\/deerflow\/(runtime|persistence|config|tools|guardrails|tracing|models|utils|uploads)\//]],
-              ['area:agents', [/^backend\/packages\/harness\/deerflow\/(agents|subagents|reflection)\//, /(^|\/)langgraph\.json$/, /^backend\/.*\/prompts\//]],
-              ['area:sandbox', [/^docker\//, /^backend\/packages\/harness\/deerflow\/sandbox\//, /(^|\/)Dockerfile$/]],
-              ['area:skills', [/^skills\//, /^backend\/packages\/harness\/deerflow\/skills\//, /^frontend\/src\/core\/skills\//]],
-              ['area:mcp', [/^backend\/packages\/harness\/deerflow\/mcp\//, /^frontend\/src\/core\/mcp\//]],
-              ['area:ci', [/^\.github\//, /^scripts\//]],
-              ['area:docs', [/^docs\//, /\.mdx?$/]],
-              ['area:deps', [/(^|\/)(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/]],
-            ];
-            const areaLabels = AREA_RULES
-              .filter(([, res]) => res.some(re => m(re)))
-              .map(([label]) => label);
-
-            // ---- size: additions+deletions, excluding lockfiles/snapshots ----
-            const EXCLUDE_SIZE = /(^|\/)(uv\.lock|pnpm-lock\.yaml|package-lock\.json)$|\.snap$/;
-            const churn = files
-              .filter(f => !EXCLUDE_SIZE.test(f.filename))
-              .reduce((s, f) => s + (f.additions || 0) + (f.deletions || 0), 0);
-            const sizeLabel =
-              churn < 20 ? 'size/XS' :
-              churn < 100 ? 'size/S' :
-              churn < 300 ? 'size/M' :
-              churn < 700 ? 'size/L' : 'size/XL';
-
-            // ---- risk ----
-            const docsOnly = paths.length > 0 && paths.every(p =>
-              /\.(md|mdx|txt)$/i.test(p) || p.startsWith('docs/') ||
-              /\.(png|jpe?g|gif|svg|webp|ico)$/i.test(p));
-            const highRisk =
-              m(/^backend\/app\/gateway\//) ||
-              m(/^backend\/packages\/harness\/deerflow\/(agents|subagents|sandbox)\//) ||
-              m(/(^|\/)langgraph\.json$/) ||
-              m(/(^|\/)(auth|authz|security)/i) ||
-              m(/(pyproject\.toml|uv\.lock|package\.json|pnpm-lock\.yaml)$/) ||
-              m(/^docker\//) ||
-              m(/^\.github\/workflows\//);
-            const riskLabel = docsOnly ? 'risk:low' : (highRisk ? 'risk:high' : 'risk:medium');
-
-            // ---- needs-validation: front/back contract surface ----
-            const contract =
-              m(/^backend\/app\/gateway\//) ||
-              m(/^backend\/packages\/harness\/deerflow\/(agents|subagents)\//) ||
-              m(/(^|\/)langgraph\.json$/) ||
-              m(/^frontend\/src\/core\/(api|threads|messages)\//);
-
-            // ---- live current labels (NOT the stale event payload) ----
-            const current = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
-              owner, repo, issue_number: num, per_page: 100,
-            })).map(l => l.name);
-            const hasSkip = current.includes('skip-validation');
-
-            // Reconcile ONLY namespaces we own; never touch others.
-            const owned = (n) =>
-              n.startsWith('area:') || n.startsWith('size/') ||
-              n.startsWith('risk:') || n === 'needs-validation';
-            const desired = new Set([...areaLabels, sizeLabel, riskLabel]);
-            if (contract && !hasSkip) desired.add('needs-validation');
-
-            const toRemove = current.filter(n => owned(n) && !desired.has(n));
-            const toAdd = [...desired].filter(n => !current.includes(n));
-
-            // first-time-contributor: add-only, on opened, real users only.
-            if (context.payload.action === 'opened' &&
-                pr.user.type === 'User' &&
-                ['FIRST_TIME_CONTRIBUTOR', 'FIRST_TIMER'].includes(pr.author_association) &&
-                !current.includes('first-time-contributor')) {
-              toAdd.push('first-time-contributor');
-            }
-
-            for (const name of toRemove) {
-              try {
-                await github.rest.issues.removeLabel({ owner, repo, issue_number: num, name });
-              } catch (e) {
-                if (e.status !== 404) throw e;
-              }
-            }
-            if (toAdd.length) {
-              await github.rest.issues.addLabels({ owner, repo, issue_number: num, labels: toAdd });
-            }
-            core.info(`area=[${areaLabels.join(',')}] ${sizeLabel} ${riskLabel} churn=${churn} ` +
-              `validation=${desired.has('needs-validation')} ` +
-              `(+${toAdd.join(',') || '-'} / -${toRemove.join(',') || '-'})`);
-
-  # ── PR: reviewing label on a maintainer's human review ─────────────────────
-  reviewing:
-    if: github.event_name == 'pull_request_review'
-    runs-on: ubuntu-latest
-    concurrency:
-      group: triage-review-${{ github.event.pull_request.number }}
-      cancel-in-progress: false
-    steps:
-      - name: Add reviewing label for maintainer reviews
-        uses: actions/github-script@v8
-        with:
-          script: |
-            const { owner, repo } = context.repo;
-            const num = context.payload.pull_request.number;
-            const review = context.payload.review;
-            const assoc = review.author_association; // payload field; no API call
-            const type = review.user && review.user.type;
-
-            // author_association is NONE for every automated reviewer
-            // (Copilot, CodeRabbit, Codex, Sourcery, ...), so this allowlist
-            // drops them all without a denylist — and never calls the
-            // collaborators API that 404s on "Copilot is not a user".
-            // user.type === 'User' guards the rare bot-added-as-collaborator case.
-            if (!['OWNER', 'MEMBER', 'COLLABORATOR'].includes(assoc) || type !== 'User') {
-              core.info(`reviewer ${review.user && review.user.login} assoc=${assoc} type=${type}; skipping.`);
-              return;
-            }
-
-            const labels = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
-              owner, repo, issue_number: num, per_page: 100,
-            })).map(l => l.name);
-            if (labels.includes('reviewing')) {
-              core.info('Already labeled reviewing; skipping.');
-              return;
-            }
-            try {
-              await github.rest.issues.addLabels({
-                owner, repo, issue_number: num, labels: ['reviewing'],
-              });
-              core.info('Added "reviewing".');
-            } catch (e) {
-              if (e.status === 403) core.info('No permission to label (expected on some fork PRs).');
-              else throw e;
-            }
-
-  # ── Issue: needs-triage on every new issue ────────────────────────────────
-  issue-triage:
-    if: github.event_name == 'issues'
-    runs-on: ubuntu-latest
-    concurrency:
-      group: triage-issue-${{ github.event.issue.number }}
-      cancel-in-progress: false
-    steps:
-      - name: Add needs-triage label
-        uses: actions/github-script@v8
-        with:
-          script: |
-            const { owner, repo } = context.repo;
-            const issue_number = context.payload.issue.number;
-
-            // Read live labels (not the event payload) so labels added at creation
-            // time via the API or by another automation are seen — consistent with
-            // the live-state reads in the PR jobs above.
-            const current = (await github.paginate(github.rest.issues.listLabelsOnIssue, {
-              owner, repo, issue_number, per_page: 100,
-            })).map(l => l.name);
-            if (current.includes('needs-triage')) {
-              core.info('Issue already has needs-triage; nothing to do.');
-              return;
-            }
-            // Self-heal: create the label if it does not exist yet.
-            try {
-              await github.rest.issues.createLabel({
-                owner, repo, name: 'needs-triage', color: 'fef2c0',
-                description: 'Awaiting maintainer triage',
-              });
-            } catch (e) {
-              if (e.status !== 422) throw e; // 422 = already exists
-            }
-            await github.rest.issues.addLabels({
-              owner, repo, issue_number, labels: ['needs-triage'],
-            });
-            core.info(`Added needs-triage to #${issue_number}.`);
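The reconciliation rule in the `pr-labels` job (touch only labels in owned namespaces, leave everything else alone) is easier to reason about as a pure function with no API calls. The sketch below reuses the size thresholds and the `owned` predicate from the script above; the function names `sizeLabelFor` and `reconcile` and the sample label lists are assumptions made for illustration.

```javascript
// Size bucket from total churn (additions + deletions), using the workflow's thresholds.
function sizeLabelFor(churn) {
  if (churn < 20) return 'size/XS';
  if (churn < 100) return 'size/S';
  if (churn < 300) return 'size/M';
  if (churn < 700) return 'size/L';
  return 'size/XL';
}

// A label is "owned" only if it lives in a namespace this job manages.
const owned = (name) =>
  name.startsWith('area:') || name.startsWith('size/') ||
  name.startsWith('risk:') || name === 'needs-validation';

// Given the live labels and the desired labels, compute what to add and what to remove.
function reconcile(current, desired) {
  const want = new Set(desired);
  return {
    // Remove only stale labels from owned namespaces; foreign labels are never touched.
    toRemove: current.filter(n => owned(n) && !want.has(n)),
    // Add whatever is desired but not present yet.
    toAdd: [...want].filter(n => !current.includes(n)),
  };
}

// Assumed example: a PR that grew and moved from docs to backend code.
const liveLabels = ['bug', 'size/S', 'area:docs', 'risk:low'];
const desiredLabels = ['area:backend', sizeLabelFor(250), 'risk:high', 'needs-validation'];
const plan = reconcile(liveLabels, desiredLabels);
console.log(plan);
```

Because the plan is derived from the live label list rather than from the event payload, running it twice in a row produces an empty plan the second time, which is the convergence property the design notes describe.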
@@ -247,9 +247,6 @@ Access: http://localhost:2026
 
 The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
 
-> [!IMPORTANT]
-> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
-
 See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
 
 #### Option 2: Local Development
@@ -343,8 +340,6 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
 
 DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
 
-DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
-
 | Channel | Transport | Difficulty |
 |---------|-----------|------------|
 | Telegram | Bot API (long-polling) | Easy |
@@ -590,8 +585,6 @@ A standard Agent Skill is a structured capability module — a Markdown file tha
 
 Skills are loaded progressively — only when the task needs them, not all at once. This keeps the context window lean and makes DeerFlow work well even with token-sensitive models.
 
-Users can explicitly activate an enabled skill for a single turn by starting the request with `/skill-name`, for example `/data-analysis analyze uploads/foo.csv`. DeerFlow loads that skill's `SKILL.md` as hidden current-turn context while leaving the base prompt limited to skill metadata. Slash activation respects disabled skills, custom-agent skill whitelists, and existing channel commands such as `/new` and `/help`.
-
 When you install `.skill` archives through the Gateway, DeerFlow accepts standard optional frontmatter metadata such as `version`, `author`, and `compatibility` instead of rejecting otherwise valid external skills.
 
 Tools follow the same philosophy. DeerFlow comes with a core toolset — web search, web fetch, file operations, bash execution — and supports custom tools via MCP servers and Python functions. Swap anything. Add anything.
@@ -24,10 +24,5 @@ config.yaml
 # Langgraph
 .langgraph_api
 
-# Sandbox runtime working dir — pre-created and excluded from uvicorn reload
-# (scripts/serve.sh, docker/dev-entrypoint.sh). Anchored so it does not match
-# the source package backend/packages/harness/deerflow/sandbox/.
-/sandbox/
-
 # Claude Code settings
 .claude/settings.local.json
+26
-58
@@ -112,14 +112,6 @@ calls are resolved by function name, so duplicate helper names in one file can
 conservatively over-report async reachability. It is intentionally
 informational and is not run from CI in this round.
 
-For a diff-scoped view of the same findings, `scripts/scan_changed_blocking_io.py`
-(repo root) reports findings on the added lines of `git diff <base>...HEAD`
-plus findings new versus the merge base (so a new async caller exposing an
-untouched sync helper in the same file is still reported) — used by the
-`blocking-io-guard` skill (`.agent/skills/blocking-io-guard/`) as the
-deterministic scope step before routing each candidate to a fix and/or a
-`tests/blocking_io/` runtime anchor.
-
 Regression tests related to Docker/provisioner behavior:
 - `tests/test_docker_sandbox_mode_detection.py` (mode detection from `config.yaml`)
 - `tests/test_provisioner_kubeconfig.py` (kubeconfig file/directory handling)
@@ -200,7 +192,7 @@ from deerflow.config import get_app_config
 
 ### Middleware Chain
 
-Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`build_middlewares`):
+Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
 
 1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
 2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
@@ -210,17 +202,16 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
 6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
 7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
-9. **SkillActivationMiddleware** - Detects strict `/skill-name task` syntax on the latest real user message, resolves only enabled and runtime-allowed skills, reads `SKILL.md` from trusted skill storage, injects the skill body as hidden current-turn model context, and records a `middleware:skill_activation` audit event with skill name, category, path, and content hash
-10. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
-11. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
-12. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
-13. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
-14. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
-15. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
-16. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`)
-17. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
-18. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
-19. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
+9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
+10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
+11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
+12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
+13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
+14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
+15. **DeferredToolFilterMiddleware** - Hides deferred (MCP) tool schemas from the bound model using a build-time deferred-name set + catalog hash, reading per-thread promotions from `ThreadState.promoted` (hash-scoped, no ContextVar); a tool becomes bound on subsequent turns after `tool_search` returns its schema (optional, if `tool_search.enabled`)
+16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
+17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
+18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
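The "strict append order" with optional entries and a fixed last element, as listed above, can be pictured as a builder that pushes names onto an array. The sketch below is language-neutral illustration written in JavaScript, not the Python builders the text names: `buildChain`, `checkClarificationLast`, and the option flags are hypothetical, and only a subset of the listed middlewares is included.

```javascript
// Hypothetical option flags; the real builders read these from config and runtime state.
function buildChain(opts = {}) {
  const chain = ['ThreadDataMiddleware', 'UploadsMiddleware'];
  if (opts.guardrails) chain.push('GuardrailMiddleware');        // optional
  chain.push('ToolErrorHandlingMiddleware');
  if (opts.summarization) chain.push('SummarizationMiddleware'); // optional
  if (opts.planMode) chain.push('TodoListMiddleware');           // optional
  if (opts.subagents) chain.push('SubagentLimitMiddleware');     // optional
  chain.push('LoopDetectionMiddleware');
  chain.push('ClarificationMiddleware');                         // always appended last
  return chain;
}

// Guard for the "must be last" invariant, usable in a unit test.
function checkClarificationLast(chain) {
  if (chain[chain.length - 1] !== 'ClarificationMiddleware') {
    throw new Error('ClarificationMiddleware must be the final middleware');
  }
  return true;
}

const minimal = buildChain();
const full = buildChain({ guardrails: true, summarization: true, planMode: true, subagents: true });
checkClarificationLast(minimal);
checkClarificationLast(full);
console.log(minimal.length, full.length);
```

Because optional entries are only ever skipped and never reordered, the relative order of any two middlewares that are both present is the same in every configuration, which is why removing one entry renumbers the list without changing the order of the rest.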
 
 ### Configuration System
 
@@ -234,7 +225,7 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
 
 **Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`.
 
-Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`, `channel_connections`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
+Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
 
 Configuration priority:
 1. Explicit `config_path` argument
@@ -272,7 +263,7 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S
 | **Uploads** (`/api/threads/{id}/uploads`) | `POST /` - upload files (auto-converts PDF/PPT/Excel/Word); `GET /list` - list; `DELETE /{filename}` - delete |
 | **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
 | **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
-| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized and inline reasoning (`<think>...</think>`, including unclosed/truncated blocks from reasoning models like MiniMax-M3) is stripped before JSON parsing |
+| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
 | **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
 | **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
 | **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
@@ -292,7 +283,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
||||||
**Implementations**:
|
**Implementations**:
|
||||||
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
||||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation. Active-cache and warm-pool entries are checked with the backend during acquire/reuse; definitively dead containers are dropped from all in-process maps so the thread can discover or create a fresh sandbox instead of reusing a stale client. Backend health-check failures are treated as unknown, not dead; local discovery likewise treats an unverifiable container as not adoptable and falls through to create rather than failing acquire. `get()` remains an in-memory lookup for event-loop-safe tool paths.
|
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
||||||
|
|
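The `LocalSandboxProvider` entry above describes an LRU cache of per-thread sandboxes guarded by a `threading.Lock`. A minimal sketch of that pattern follows. The class name `LockedLRUCache` and the method `get_or_create` are hypothetical illustrations and are not taken from the DeerFlow source; only the default of 256 entries and the `local:{thread_id}` id format come from the text.

```python
import threading
from collections import OrderedDict
from typing import Callable, Generic, TypeVar

K = TypeVar("K")
V = TypeVar("V")


class LockedLRUCache(Generic[K, V]):
    """Hypothetical stand-in for a per-thread sandbox cache (not the real class)."""

    def __init__(self, max_entries: int = 256) -> None:
        self._max_entries = max_entries
        self._entries: "OrderedDict[K, V]" = OrderedDict()
        self._lock = threading.Lock()

    def get_or_create(self, key: K, factory: Callable[[], V]) -> V:
        # The lock covers lookup, insertion, and eviction so concurrent
        # callers never observe a half-updated cache.
        with self._lock:
            if key in self._entries:
                # Mark the entry as most recently used.
                self._entries.move_to_end(key)
                return self._entries[key]
            value = factory()
            self._entries[key] = value
            if len(self._entries) > self._max_entries:
                # Evict the least recently used entry.
                self._entries.popitem(last=False)
            return value

    def __len__(self) -> int:
        with self._lock:
            return len(self._entries)
```

In this sketch the factory runs while the lock is held, which keeps the logic simple but would serialize slow sandbox creation; whether the real provider makes the same trade-off is not stated in the text.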
||||||
**Virtual Path System**:
|
**Virtual Path System**:
|
||||||
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
|
||||||
@@ -314,7 +305,6 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` enforced by `SubagentLimitMiddleware` (truncates excess tool calls in `after_model`), 15-minute timeout
|
**Concurrency**: `MAX_CONCURRENT_SUBAGENTS = 3` enforced by `SubagentLimitMiddleware` (truncates excess tool calls in `after_model`), 15-minute timeout
|
||||||
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
|
**Flow**: `task()` tool → `SubagentExecutor` → background thread → poll 5s → SSE events → result
|
||||||
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
|
**Events**: `task_started`, `task_running`, `task_completed`/`task_failed`/`task_timed_out`
|
||||||
**Deferred MCP tools** (if `tool_search.enabled`): `SubagentExecutor._build_initial_state` assembles deferral after policy filtering via the shared `assemble_deferred_tools` (fail-closed), appends the `tool_search` tool, injects the `<available-deferred-tools>` section into the subagent's `SystemMessage`, and threads the setup to `_create_agent`, which attaches `DeferredToolFilterMiddleware` through `build_subagent_runtime_middlewares(deferred_setup=...)`. Subagents thus withhold full MCP schemas until promotion, same as the lead agent; each task run gets a fresh `ThreadState` so promotion is isolated per run
|
|
||||||
|
|
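The **Concurrency** note above says excess tool calls are truncated in `after_model`. The sketch below shows one way such a truncation could look. It assumes tool calls are plain dicts with a `name` key and that only `task` calls count toward the limit; both are assumptions for illustration, and `truncate_task_calls` is a hypothetical helper, not the `SubagentLimitMiddleware` implementation.

```python
from typing import Any, Dict, List

MAX_CONCURRENT_SUBAGENTS = 3


def truncate_task_calls(
    tool_calls: List[Dict[str, Any]],
    limit: int = MAX_CONCURRENT_SUBAGENTS,
) -> List[Dict[str, Any]]:
    """Keep at most `limit` calls named "task"; pass other calls through."""
    kept: List[Dict[str, Any]] = []
    task_count = 0
    for call in tool_calls:
        if call.get("name") == "task":
            task_count += 1
            if task_count > limit:
                # Excess subagent launches are dropped, not queued.
                continue
        kept.append(call)
    return kept
```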
||||||
### Tool System (`packages/harness/deerflow/tools/`)
|
### Tool System (`packages/harness/deerflow/tools/`)
|
||||||
|
|
||||||
@@ -357,7 +347,6 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
- **Format**: Directory with `SKILL.md` (YAML frontmatter: name, description, license, allowed-tools)
|
- **Format**: Directory with `SKILL.md` (YAML frontmatter: name, description, license, allowed-tools)
|
||||||
- **Loading**: `load_skills()` recursively scans `skills/{public,custom}` for `SKILL.md`, parses metadata, and reads enabled state from extensions_config.json
|
- **Loading**: `load_skills()` recursively scans `skills/{public,custom}` for `SKILL.md`, parses metadata, and reads enabled state from extensions_config.json
|
||||||
- **Injection**: Enabled skills listed in agent system prompt with container paths
|
- **Injection**: Enabled skills listed in agent system prompt with container paths
|
||||||
- **Slash activation**: `/skill-name task` loads that enabled skill's `SKILL.md` for the current model call only. The resolver rejects leading whitespace, missing separators, reserved channel commands (`/new`, `/help`, `/bootstrap`, `/status`, `/models`, `/memory`), disabled skills, and skills outside a custom agent's whitelist.
|
|
||||||
- **Installation**: `POST /api/skills/install` extracts .skill ZIP archive to custom/ directory
|
- **Installation**: `POST /api/skills/install` extracts .skill ZIP archive to custom/ directory
|
||||||
|
|
||||||
### Model Factory (`packages/harness/deerflow/models/factory.py`)
|
### Model Factory (`packages/harness/deerflow/models/factory.py`)
|
||||||
@@ -377,32 +366,29 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
|
|
||||||
### IM Channels System (`app/channels/`)
|
### IM Channels System (`app/channels/`)
|
||||||
|
|
||||||
Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||||
|
|
||||||
|
|
||||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
||||||
|
|
||||||
**Components**:
|
**Components**:
|
||||||
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
|
||||||
- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` → `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations)
|
- `store.py` - JSON-file persistence mapping `channel_name:chat_id[:topic_id]` → `thread_id` (keys are `channel:chat` for root conversations and `channel:chat:topic` for threaded conversations)
|
||||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Discord on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu/Telegram incremental outbound updates
|
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
||||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||||
- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `telegram.py` registers the "Working on it..." placeholder as the stream target and edits it in place via `editMessageText`; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||||
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
|
|
||||||
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
|
|
||||||
|
|
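The `store.py` entry above gives the key format `channel_name:chat_id[:topic_id]`. As a small illustration, the helper below builds such keys. The function name `make_store_key` is hypothetical; only the key layout comes from the text.

```python
from typing import Optional


def make_store_key(channel_name: str, chat_id: str, topic_id: Optional[str] = None) -> str:
    """Build `channel:chat` for root conversations, `channel:chat:topic` for threads."""
    if topic_id:
        return f"{channel_name}:{chat_id}:{topic_id}"
    return f"{channel_name}:{chat_id}"
```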
||||||
**Message Flow**:
|
**Message Flow**:
|
||||||
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
||||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||||
3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
|
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||||
4. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||||
5. Feishu/Telegram chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||||
6. Slack/Discord chat: `runs.wait()` → extract final response → publish outbound
|
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||||
7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||||
8. Telegram streaming: the "Working on it..." placeholder message is registered as the stream target; non-final updates `editMessageText` it in place (channel-side throttle: 1s in private chats, 3s in groups due to Telegram's 20 msg/min group cap; 4096-char truncation; rate-limited updates dropped); the final update performs the last edit and splits >4096 texts into follow-up messages
|
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||||
9. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
9. Outbound → channel callbacks → platform reply
|
||||||
10. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
|
||||||
11. Outbound → channel callbacks → platform reply
|
|
||||||
|
|
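The Telegram streaming step above mentions a channel-side throttle (1s in private chats, 3s in groups), dropped rate-limited updates, and splitting final texts longer than 4096 characters. The sketch below illustrates those three rules in isolation. All function names are hypothetical, and the plain fixed-width split is an assumption: the text does not say where the real channel breaks long messages.

```python
from typing import List, Optional

TELEGRAM_MAX_CHARS = 4096


def throttle_interval(is_group: bool) -> float:
    """Minimum seconds between in-place edits for one chat."""
    return 3.0 if is_group else 1.0


def should_edit(now: float, last_edit_at: Optional[float], is_group: bool) -> bool:
    """Return False when a non-final update arrives too soon and should be dropped."""
    if last_edit_at is None:
        return True
    return now - last_edit_at >= throttle_interval(is_group)


def split_final_text(text: str, limit: int = TELEGRAM_MAX_CHARS) -> List[str]:
    """First chunk goes into the last edit; the rest become follow-up messages."""
    if not text:
        return [""]
    return [text[i : i + limit] for i in range(0, len(text), limit)]
```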
||||||
**Configuration** (`config.yaml` -> `channels`):
|
**Configuration** (`config.yaml` -> `channels`):
|
||||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
||||||
@@ -410,17 +396,6 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
|
|||||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
||||||
|
|
||||||
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
|
|
||||||
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
|
|
||||||
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
|
|
||||||
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers.
|
|
||||||
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
|
|
||||||
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
|
|
||||||
- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user.
|
|
||||||
- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable.
|
|
||||||
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
|
|
||||||
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
|
|
||||||
|
|
||||||
|
|
||||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||||
|
|
||||||
@@ -451,12 +426,6 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
|
|||||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||||
|
|
||||||
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
|
|
||||||
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
|
|
||||||
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
|
|
||||||
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
|
|
||||||
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
|
|
||||||
|
|
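The token-counting notes above describe caching a failed tiktoken load with a timestamp and skipping retries during a 600s cooldown. A minimal sketch of that caching rule follows. `CooldownLoader` is a hypothetical name, the injected `load` callable stands in for the real encoding load, and the LOADING sentinel for in-flight loads is deliberately left out to keep the example short.

```python
import time
from typing import Any, Callable, Optional

_TIKTOKEN_RETRY_COOLDOWN_S = 600.0


class CooldownLoader:
    """Hypothetical lazy loader that remembers failures for a fixed cooldown."""

    def __init__(
        self,
        load: Callable[[], Any],
        cooldown_s: float = _TIKTOKEN_RETRY_COOLDOWN_S,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._load = load
        self._cooldown_s = cooldown_s
        self._clock = clock
        self._value: Optional[Any] = None
        self._failed_at: Optional[float] = None

    def get(self) -> Optional[Any]:
        """Return the loaded value, or None when the caller should fall back."""
        if self._value is not None:
            return self._value
        now = self._clock()
        if self._failed_at is not None and now - self._failed_at < self._cooldown_s:
            # Still cooling down: do not re-trigger the blocking load.
            return None
        try:
            self._value = self._load()
        except Exception:
            # Remember when the load failed so later calls can skip it.
            self._failed_at = now
            return None
        return self._value
```

A caller that receives `None` would use the char-based estimate for that call, and the first call after the cooldown retries the load, which is how a transient outage can recover without a restart.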
||||||
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
||||||
|
|
||||||
**Configuration** (`config.yaml` → `memory`):
|
**Configuration** (`config.yaml` → `memory`):
|
||||||
@@ -466,7 +435,6 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
- `model_name` - LLM for updates (null = default model)
|
- `model_name` - LLM for updates (null = default model)
|
||||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||||
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
||||||
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
|
|
||||||
|
|
||||||
### Reflection System (`packages/harness/deerflow/reflection/`)
|
### Reflection System (`packages/harness/deerflow/reflection/`)
|
||||||
|
|
||||||
@@ -524,7 +492,7 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
|
|||||||
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
|
- `"messages-tuple"` — per-chunk update: for AI text this is a **delta** (concat per `id` to rebuild the full message); tool calls and tool results are emitted once each
|
||||||
- `"custom"` — forwarded from `StreamWriter`
|
- `"custom"` — forwarded from `StreamWriter`
|
||||||
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
|
- `"end"` — stream finished (carries cumulative `usage` counted once per message id)
|
||||||
- Agent created lazily via `create_agent()` + `build_middlewares()`, same as `make_lead_agent`
|
- Agent created lazily via `create_agent()` + `_build_middlewares()`, same as `make_lead_agent`
|
||||||
- Supports `checkpointer` parameter for state persistence across turns
|
- Supports `checkpointer` parameter for state persistence across turns
|
||||||
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
- `reset_agent()` forces agent recreation (e.g. after memory or skill changes)
|
||||||
- See [docs/STREAMING.md](docs/STREAMING.md) for the full design: why Gateway and DeerFlowClient are parallel paths, LangGraph's `stream_mode` semantics, the per-id dedup invariants, and regression testing strategy
|
- See [docs/STREAMING.md](docs/STREAMING.md) for the full design: why Gateway and DeerFlowClient are parallel paths, LangGraph's `stream_mode` semantics, the per-id dedup invariants, and regression testing strategy
|
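The `"messages-tuple"` entry above says AI text arrives as deltas that must be concatenated per `id` to rebuild the full message. The sketch below shows that accumulation. The `(message_id, delta)` tuple shape is an assumption made for illustration and is not the actual event payload; `rebuild_messages` is a hypothetical helper.

```python
from typing import Dict, Iterable, Tuple


def rebuild_messages(chunks: Iterable[Tuple[str, str]]) -> Dict[str, str]:
    """Concatenate text deltas per message id, preserving first-seen order."""
    full: Dict[str, str] = {}
    for message_id, delta in chunks:
        full[message_id] = full.get(message_id, "") + delta
    return full
```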
||||||
|
|||||||
+1
-1
@@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern:
|
|||||||
Per-thread isolated execution with virtual path translation:
|
Per-thread isolated execution with virtual path translation:
|
||||||
|
|
||||||
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. `AioSandboxProvider` validates active-cache and warm-pool containers during acquire/reuse, dropping definitively dead entries so a thread can provision a fresh sandbox after an unexpected container exit while keeping `get()` as an in-memory lookup. Backend health-check failures are treated as unknown, not dead, and a container that cannot be verified during discovery is simply not adopted (acquire falls through to create instead of failing).
|
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
|
||||||
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
||||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||||
|
|||||||
@@ -18,21 +18,3 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
|||||||
"/help",
|
"/help",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_connect_code(text: str) -> str | None:
|
|
||||||
"""Extract the one-time channel binding code from a connect command."""
|
|
||||||
parts = text.strip().split()
|
|
||||||
if len(parts) < 2:
|
|
||||||
return None
|
|
||||||
command = parts[0].lower()
|
|
||||||
if command in {"/connect", "connect"}:
|
|
||||||
return parts[1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def is_known_channel_command(text: str) -> bool:
|
|
||||||
"""Return whether text starts with a registered channel control command."""
|
|
||||||
if not text.startswith("/"):
|
|
||||||
return False
|
|
||||||
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
|
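To make the behaviour of the two removed helpers concrete, the sketch below restates them in a self-contained form. The function bodies follow the diff above. The contents of `KNOWN_CHANNEL_COMMANDS` are an assumption: only `"/help"` is visible in this hunk, and the other entries are borrowed from the reserved-command list in the documentation diff.

```python
from typing import Optional

# Assumed contents: only "/help" is visible in the hunk above.
KNOWN_CHANNEL_COMMANDS = frozenset(
    {"/new", "/help", "/bootstrap", "/status", "/models", "/memory"}
)


def extract_connect_code(text: str) -> Optional[str]:
    """Extract the one-time channel binding code from a connect command."""
    parts = text.strip().split()
    if len(parts) < 2:
        return None
    command = parts[0].lower()
    # Both "/connect <code>" and a bare "connect <code>" are accepted.
    if command in {"/connect", "connect"}:
        return parts[1]
    return None


def is_known_channel_command(text: str) -> bool:
    """Return whether text starts with a registered channel control command."""
    # No strip here: leading whitespace means the text is not a command.
    if not text.startswith("/"):
        return False
    return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
```

Two details are easy to miss: the command word is matched case-insensitively while the code itself is returned unchanged, and `extract_connect_code` strips surrounding whitespace whereas `is_known_channel_command` does not.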
||||||
|
|||||||
@@ -1,44 +0,0 @@
|
|||||||
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.channels.message_bus import InboundMessage
|
|
||||||
|
|
||||||
|
|
||||||
async def attach_connection_identity(
|
|
||||||
inbound: InboundMessage,
|
|
||||||
*,
|
|
||||||
repo: Any,
|
|
||||||
provider: str,
|
|
||||||
workspace_id: str | None,
|
|
||||||
fallback_without_workspace: bool = False,
|
|
||||||
) -> InboundMessage:
|
|
||||||
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
|
||||||
if repo is None:
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
workspace_candidates: list[str | None] = []
|
|
||||||
if workspace_id:
|
|
||||||
workspace_candidates.append(workspace_id)
|
|
||||||
if fallback_without_workspace:
|
|
||||||
workspace_candidates.append(None)
|
|
||||||
if not workspace_candidates:
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
for candidate in workspace_candidates:
|
|
||||||
connection = await repo.find_connection_by_external_identity(
|
|
||||||
provider=provider,
|
|
||||||
external_account_id=inbound.user_id,
|
|
||||||
workspace_id=candidate,
|
|
||||||
)
|
|
||||||
if connection is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
inbound.connection_id = connection["id"]
|
|
||||||
inbound.owner_user_id = connection["owner_user_id"]
|
|
||||||
inbound.workspace_id = connection.get("workspace_id")
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
return inbound
|
|
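The lookup order in `attach_connection_identity` (workspace-scoped first, then an optional lookup without a workspace) can be exercised with in-memory stand-ins. In the sketch below, `FakeInbound` and `FakeRepo` are hypothetical test doubles and `attach` is a condensed restatement of the removed function; only the `find_connection_by_external_identity` keyword arguments and the three attached fields come from the diff.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class FakeInbound:
    """Hypothetical stand-in for InboundMessage with only the fields used here."""

    user_id: str
    connection_id: Optional[str] = None
    owner_user_id: Optional[str] = None
    workspace_id: Optional[str] = None


class FakeRepo:
    """Hypothetical in-memory repository keyed by (provider, account, workspace)."""

    def __init__(self, rows: Dict[Tuple[str, str, Optional[str]], Dict[str, Any]]) -> None:
        self.rows = rows
        self.lookups: List[Optional[str]] = []

    async def find_connection_by_external_identity(
        self, *, provider: str, external_account_id: str, workspace_id: Optional[str]
    ) -> Optional[Dict[str, Any]]:
        self.lookups.append(workspace_id)
        return self.rows.get((provider, external_account_id, workspace_id))


async def attach(
    inbound: FakeInbound,
    *,
    repo: Optional[FakeRepo],
    provider: str,
    workspace_id: Optional[str],
    fallback_without_workspace: bool = False,
) -> FakeInbound:
    if repo is None:
        return inbound
    candidates: List[Optional[str]] = []
    if workspace_id:
        candidates.append(workspace_id)
    if fallback_without_workspace:
        candidates.append(None)
    for candidate in candidates:
        connection = await repo.find_connection_by_external_identity(
            provider=provider,
            external_account_id=inbound.user_id,
            workspace_id=candidate,
        )
        if connection is None:
            continue
        # First match wins: copy ownership metadata onto the message.
        inbound.connection_id = connection["id"]
        inbound.owner_user_id = connection["owner_user_id"]
        inbound.workspace_id = connection.get("workspace_id")
        return inbound
    return inbound
```

With a binding stored only under `workspace_id=None`, a group message with `fallback_without_workspace=True` first misses on the group id and then matches the workspace-less row, so a binding made in a private chat would still apply in a group under this logic.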
||||||
@@ -14,8 +14,7 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -60,7 +59,9 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _is_dingtalk_command(text: str) -> bool:
|
def _is_dingtalk_command(text: str) -> bool:
|
||||||
return is_known_channel_command(text)
|
if not text.startswith("/"):
|
||||||
|
return False
|
||||||
|
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_from_rich_text(rich_text_list: list) -> str:
|
def _extract_text_from_rich_text(rich_text_list: list) -> str:
|
||||||
@@ -137,7 +138,6 @@ class DingTalkChannel(Channel):
|
|||||||
self._incoming_messages: dict[str, Any] = {}
|
self._incoming_messages: dict[str, Any] = {}
|
||||||
self._incoming_messages_lock = threading.Lock()
|
self._incoming_messages_lock = threading.Lock()
|
||||||
self._card_repliers: dict[str, Any] = {}
|
self._card_repliers: dict[str, Any] = {}
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
@@ -397,24 +397,6 @@ class DingTalkChannel(Channel):
|
|||||||
text[:100],
|
text[:100],
|
||||||
)
|
)
|
||||||
|
|
||||||
connect_code = extract_connect_code(text)
|
|
||||||
if connect_code and self._connection_repo is not None:
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
|
||||||
fut = asyncio.run_coroutine_threadsafe(
|
|
||||||
self._bind_connection_from_connect_code(
|
|
||||||
conversation_type=conversation_type,
|
|
||||||
sender_staff_id=sender_staff_id,
|
|
||||||
sender_nick=sender_nick,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
code=connect_code,
|
|
||||||
),
|
|
||||||
self._main_loop,
|
|
||||||
)
|
|
||||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
|
||||||
else:
|
|
||||||
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
|
||||||
return
|
|
||||||
|
|
||||||
if _is_dingtalk_command(text):
|
if _is_dingtalk_command(text):
|
||||||
msg_type = InboundMessageType.COMMAND
|
msg_type = InboundMessageType.COMMAND
|
||||||
else:
|
else:
|
||||||
@@ -470,95 +452,11 @@ class DingTalkChannel(Channel):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||||
inbound = await self._attach_connection_identity(inbound)
|
|
||||||
# Running reply must finish before publish_inbound so AI card tracks are
|
# Running reply must finish before publish_inbound so AI card tracks are
|
||||||
# registered before the manager emits streaming outbounds.
|
# registered before the manager emits streaming outbounds.
|
||||||
await self._send_running_reply(chat_id, inbound)
|
await self._send_running_reply(chat_id, inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
|
|
||||||
return conversation_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
|
||||||
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
|
|
||||||
conversation_id = str(inbound.metadata.get("conversation_id") or "")
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="dingtalk",
|
|
||||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
|
||||||
fallback_without_workspace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
conversation_type: str,
|
|
||||||
sender_staff_id: str,
|
|
||||||
sender_nick: str,
|
|
||||||
conversation_id: str,
|
|
||||||
code: str,
|
|
||||||
) -> bool:
|
|
||||||
if self._connection_repo is None or not code:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
|
|
||||||
if state is None:
|
|
||||||
await self._send_connection_reply(
|
|
||||||
conversation_type,
|
|
||||||
sender_staff_id,
|
|
||||||
conversation_id,
|
|
||||||
"DingTalk connection code is invalid or expired.",
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not sender_staff_id:
|
|
||||||
await self._send_connection_reply(
|
|
||||||
conversation_type,
|
|
||||||
sender_staff_id,
|
|
||||||
conversation_id,
|
|
||||||
"DingTalk connection could not be completed from this message.",
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=state["owner_user_id"],
|
|
||||||
provider="dingtalk",
|
|
||||||
external_account_id=sender_staff_id,
|
|
||||||
external_account_name=sender_nick or None,
|
|
||||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
|
||||||
metadata={
|
|
||||||
"conversation_type": conversation_type,
|
|
||||||
"conversation_id": conversation_id,
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
await self._send_connection_reply(
|
|
||||||
conversation_type,
|
|
||||||
sender_staff_id,
|
|
||||||
conversation_id,
|
|
||||||
"DingTalk connected to DeerFlow.",
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _send_connection_reply(
|
|
||||||
self,
|
|
||||||
conversation_type: str,
|
|
||||||
sender_staff_id: str,
|
|
||||||
conversation_id: str,
|
|
||||||
text: str,
|
|
||||||
) -> None:
|
|
||||||
robot_code = self._client_id
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
|
||||||
if conversation_id:
|
|
||||||
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
|
||||||
return
|
|
||||||
if sender_staff_id:
|
|
||||||
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
||||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -71,7 +69,6 @@ class DiscordChannel(Channel):
|
|||||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._discord_module = None
|
self._discord_module = None
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -289,10 +286,6 @@ class DiscordChannel(Channel):
|
|||||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||||
|
|
||||||
connect_code = extract_connect_code(text)
|
|
||||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
|
||||||
return
|
|
||||||
|
|
||||||
# --- Determine thread/channel routing and typing target ---
|
# --- Determine thread/channel routing and typing target ---
|
||||||
thread_id = None
|
thread_id = None
|
||||||
chat_id = None
|
chat_id = None
|
||||||
@@ -307,7 +300,7 @@ class DiscordChannel(Channel):
|
|||||||
|
|
||||||
# If this is a known active thread, process normally
|
# If this is a known active thread, process normally
|
||||||
if thread_id in self._active_thread_ids:
|
if thread_id in self._active_thread_ids:
|
||||||
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=str(message.author.id),
|
user_id=str(message.author.id),
|
||||||
@@ -321,7 +314,6 @@ class DiscordChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
|
||||||
self._publish(inbound)
|
self._publish(inbound)
|
||||||
# Start typing indicator in the thread
|
# Start typing indicator in the thread
|
||||||
if typing_target:
|
if typing_target:
|
||||||
@@ -415,7 +407,7 @@ class DiscordChannel(Channel):
|
|||||||
chat_id = channel_id
|
chat_id = channel_id
|
||||||
typing_target = message.channel # Type into the channel
|
typing_target = message.channel # Type into the channel
|
||||||
|
|
||||||
msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=str(message.author.id),
|
user_id=str(message.author.id),
|
||||||
@@ -429,7 +421,6 @@ class DiscordChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
|
||||||
|
|
||||||
# Start typing indicator in the correct target (thread or channel)
|
# Start typing indicator in the correct target (thread or channel)
|
||||||
if typing_target:
|
if typing_target:
|
||||||
@@ -444,60 +435,6 @@ class DiscordChannel(Channel):
|
|||||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="discord",
|
|
||||||
workspace_id=guild_id,
|
|
||||||
fallback_without_workspace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
|
||||||
if self._connection_repo is None or not code:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
|
|
||||||
if state is None:
|
|
||||||
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
guild = getattr(message, "guild", None)
|
|
||||||
channel = getattr(message, "channel", None)
|
|
||||||
author = getattr(message, "author", None)
|
|
||||||
user_id = str(getattr(author, "id", "") or "")
|
|
||||||
if not user_id:
|
|
||||||
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
guild_id = str(getattr(guild, "id", "") or "") or None
|
|
||||||
await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=state["owner_user_id"],
|
|
||||||
provider="discord",
|
|
||||||
external_account_id=user_id,
|
|
||||||
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
|
|
||||||
workspace_id=guild_id,
|
|
||||||
workspace_name=getattr(guild, "name", None) if guild is not None else None,
|
|
||||||
metadata={
|
|
||||||
"guild_id": guild_id,
|
|
||||||
"channel_id": str(getattr(channel, "id", "") or ""),
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _send_connection_reply(message, text: str) -> None:
|
|
||||||
channel = getattr(message, "channel", None)
|
|
||||||
send = getattr(channel, "send", None)
|
|
||||||
if send is None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
await send(text)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Discord] failed to send connection reply")
|
|
||||||
|
|
||||||
def _run_client(self) -> None:
|
def _run_client(self) -> None:
|
||||||
self._discord_loop = asyncio.new_event_loop()
|
self._discord_loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(self._discord_loop)
|
asyncio.set_event_loop(self._discord_loop)
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ import time
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import (
|
from app.channels.message_bus import (
|
||||||
PENDING_CLARIFICATION_METADATA_KEY,
|
PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
@@ -31,7 +30,9 @@ PENDING_CLARIFICATION_TTL_SECONDS = 30 * 60
|
|||||||
|
|
||||||
|
|
||||||
def _is_feishu_command(text: str) -> bool:
|
def _is_feishu_command(text: str) -> bool:
|
||||||
return is_known_channel_command(text)
|
if not text.startswith("/"):
|
||||||
|
return False
|
||||||
|
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
|
||||||
|
|
||||||
|
|
||||||
class FeishuChannel(Channel):
|
class FeishuChannel(Channel):
|
||||||
@@ -72,7 +73,6 @@ class FeishuChannel(Channel):
|
|||||||
self._CreateImageRequestBody = None
|
self._CreateImageRequestBody = None
|
||||||
self._GetMessageResourceRequest = None
|
self._GetMessageResourceRequest = None
|
||||||
self._thread_lock = threading.Lock()
|
self._thread_lock = threading.Lock()
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _non_empty_str(value: Any) -> str | None:
|
def _non_empty_str(value: Any) -> str | None:
|
||||||
@@ -88,23 +88,6 @@ class FeishuChannel(Channel):
|
|||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
|
||||||
def is_running(self) -> bool:
|
|
||||||
if not self._running:
|
|
||||||
return False
|
|
||||||
return self._thread is not None and self._thread.is_alive()
|
|
||||||
|
|
||||||
def _build_event_handler(self, lark):
|
|
||||||
return (
|
|
||||||
lark.EventDispatcherHandler.builder("", "")
|
|
||||||
.register_p2_im_message_receive_v1(self._on_message)
|
|
||||||
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
|
|
||||||
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
|
|
||||||
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
|
|
||||||
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
@@ -198,7 +181,7 @@ class FeishuChannel(Channel):
|
|||||||
# thread's uvloop.
|
# thread's uvloop.
|
||||||
_ws_client_mod.loop = loop
|
_ws_client_mod.loop = loop
|
||||||
|
|
||||||
event_handler = self._build_event_handler(lark)
|
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
|
||||||
ws_client = lark.ws.Client(
|
ws_client = lark.ws.Client(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
app_secret=app_secret,
|
app_secret=app_secret,
|
||||||
@@ -210,10 +193,6 @@ class FeishuChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.exception("Feishu WebSocket error")
|
logger.exception("Feishu WebSocket error")
|
||||||
self._running = False
|
|
||||||
|
|
||||||
def _on_ignored_message_event(self, event) -> None:
|
|
||||||
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -749,47 +728,11 @@ class FeishuChannel(Channel):
|
|||||||
|
|
||||||
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
||||||
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
||||||
inbound = await self._attach_connection_identity(inbound)
|
|
||||||
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
||||||
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
||||||
self._ensure_running_card_started(msg_id)
|
self._ensure_running_card_started(msg_id)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="feishu",
|
|
||||||
workspace_id=inbound.chat_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
|
|
||||||
if self._connection_repo is None or not code:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
|
|
||||||
if state is None:
|
|
||||||
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not user_id or not chat_id:
|
|
||||||
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=state["owner_user_id"],
|
|
||||||
provider="feishu",
|
|
||||||
external_account_id=user_id,
|
|
||||||
workspace_id=chat_id,
|
|
||||||
metadata={
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"message_id": message_id,
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _on_message(self, event) -> None:
|
def _on_message(self, event) -> None:
|
||||||
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
||||||
try:
|
try:
|
||||||
@@ -878,23 +821,6 @@ class FeishuChannel(Channel):
|
|||||||
logger.info("[Feishu] empty text, ignoring message")
|
logger.info("[Feishu] empty text, ignoring message")
|
||||||
return
|
return
|
||||||
|
|
||||||
connect_code = extract_connect_code(text)
|
|
||||||
if connect_code and self._connection_repo is not None:
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
|
||||||
fut = asyncio.run_coroutine_threadsafe(
|
|
||||||
self._bind_connection_from_connect_code(
|
|
||||||
message_id=msg_id,
|
|
||||||
chat_id=chat_id,
|
|
||||||
user_id=sender_id,
|
|
||||||
code=connect_code,
|
|
||||||
),
|
|
||||||
self._main_loop,
|
|
||||||
)
|
|
||||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
|
||||||
else:
|
|
||||||
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Only treat known slash commands as commands; absolute paths and
|
# Only treat known slash commands as commands; absolute paths and
|
||||||
# other slash-prefixed text should be handled as normal chat.
|
# other slash-prefixed text should be handled as normal chat.
|
||||||
if _is_feishu_command(text):
|
if _is_feishu_command(text):
|
||||||
|
|||||||
+45
-292
@@ -8,7 +8,6 @@ import mimetypes
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable, Mapping
|
from collections.abc import Awaitable, Callable, Mapping
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -27,13 +26,8 @@ from app.channels.message_bus import (
|
|||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
||||||
from app.gateway.internal_auth import create_internal_auth_headers
|
from app.gateway.internal_auth import create_internal_auth_headers
|
||||||
from deerflow.config.agents_config import load_agent_config
|
|
||||||
from deerflow.config.paths import make_safe_user_id
|
from deerflow.config.paths import make_safe_user_id
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.skills.slash import parse_slash_skill_reference
|
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
|
||||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
|
||||||
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -49,11 +43,6 @@ DEFAULT_RUN_CONTEXT: dict[str, Any] = {
|
|||||||
"subagent_enabled": False,
|
"subagent_enabled": False,
|
||||||
}
|
}
|
||||||
STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
|
||||||
# Stream modes requested from the runtime, and the SSE event names under which
|
|
||||||
# the message-tuple stream may arrive: the embedded runtime (and LangGraph
|
|
||||||
# Platform) deliver the requested "messages-tuple" mode as event "messages".
|
|
||||||
STREAM_MODES = ["messages-tuple", "values"]
|
|
||||||
MESSAGE_STREAM_EVENTS = ("messages-tuple", "messages")
|
|
||||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||||
|
|
||||||
CHANNEL_CAPABILITIES = {
|
CHANNEL_CAPABILITIES = {
|
||||||
@@ -61,7 +50,7 @@ CHANNEL_CAPABILITIES = {
|
|||||||
"discord": {"supports_streaming": False},
|
"discord": {"supports_streaming": False},
|
||||||
"feishu": {"supports_streaming": True},
|
"feishu": {"supports_streaming": True},
|
||||||
"slack": {"supports_streaming": False},
|
"slack": {"supports_streaming": False},
|
||||||
"telegram": {"supports_streaming": True},
|
"telegram": {"supports_streaming": False},
|
||||||
"wechat": {"supports_streaming": False},
|
"wechat": {"supports_streaming": False},
|
||||||
"wecom": {"supports_streaming": True},
|
"wecom": {"supports_streaming": True},
|
||||||
}
|
}
|
||||||
@@ -135,16 +124,6 @@ class InvalidChannelSessionConfigError(ValueError):
|
|||||||
"""Raised when IM channel session overrides contain invalid agent config."""
|
"""Raised when IM channel session overrides contain invalid agent config."""
|
||||||
|
|
||||||
|
|
||||||
class SlashSkillCommandResolutionError(RuntimeError):
|
|
||||||
"""Raised when IM slash-skill command resolution cannot complete safely."""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class _SlashSkillCommandResolution:
|
|
||||||
route_to_chat: bool = False
|
|
||||||
failure_message: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_thread_busy_error(exc: BaseException | None) -> bool:
|
def _is_thread_busy_error(exc: BaseException | None) -> bool:
|
||||||
if exc is None:
|
if exc is None:
|
||||||
return False
|
return False
|
||||||
@@ -279,22 +258,6 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
|
|
||||||
channel_source: dict[str, Any] = {
|
|
||||||
"type": "im_channel",
|
|
||||||
"provider": msg.channel_name,
|
|
||||||
"chat_id": msg.chat_id,
|
|
||||||
}
|
|
||||||
if msg.topic_id:
|
|
||||||
channel_source["topic_id"] = msg.topic_id
|
|
||||||
if msg.thread_ts:
|
|
||||||
channel_source["thread_ts"] = msg.thread_ts
|
|
||||||
if msg.connection_id:
|
|
||||||
channel_source["connection_id"] = msg.connection_id
|
|
||||||
|
|
||||||
return {"channel_source": channel_source}
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_content(content: Any) -> str:
|
def _extract_text_content(content: Any) -> str:
|
||||||
"""Extract text from a streaming payload content field."""
|
"""Extract text from a streaming payload content field."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -447,83 +410,6 @@ def _format_artifact_text(artifacts: list[str]) -> str:
|
|||||||
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
_OUTPUTS_VIRTUAL_PREFIX = "/mnt/user-data/outputs/"
|
||||||
|
|
||||||
|
|
||||||
def _unknown_command_reply(command: str | None = None) -> str:
|
|
||||||
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
|
|
||||||
if command:
|
|
||||||
return f"Unknown command: /{command}. Available commands: {available}"
|
|
||||||
return f"Unknown command. Available commands: {available}"
|
|
||||||
|
|
||||||
|
|
||||||
def _human_input_message(content: str, *, original_content: str | None = None) -> dict[str, Any]:
|
|
||||||
message: dict[str, Any] = {"role": "human", "content": content}
|
|
||||||
if original_content is not None and original_content != content:
|
|
||||||
message["additional_kwargs"] = {ORIGINAL_USER_CONTENT_KEY: original_content}
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
def _auth_disabled_owner_user_id() -> str | None:
|
|
||||||
try:
|
|
||||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
|
||||||
return None
|
|
||||||
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
|
||||||
|
|
||||||
|
|
||||||
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
|
||||||
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
|
||||||
owner_user_id = _effective_owner_user_id(msg)
|
|
||||||
if owner_user_id:
|
|
||||||
msg.owner_user_id = owner_user_id
|
|
||||||
return msg
|
|
||||||
|
|
||||||
|
|
||||||
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
|
||||||
owner_user_id = _effective_owner_user_id(msg)
|
|
||||||
if not owner_user_id:
|
|
||||||
return None
|
|
||||||
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_user_id_for_run(raw_user_id: str) -> str:
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
|
|
||||||
try:
|
|
||||||
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to prepare channel run user directory")
|
|
||||||
return make_safe_user_id(raw_user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_slash_skill_command(
|
|
||||||
text: str,
|
|
||||||
available_skills: set[str] | None = None,
|
|
||||||
storage: SkillStorage | Callable[[], SkillStorage] | None = None,
|
|
||||||
) -> _SlashSkillCommandResolution | None:
|
|
||||||
reference = parse_slash_skill_reference(text)
|
|
||||||
if reference is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
resolved_storage = storage() if callable(storage) else storage or get_or_new_skill_storage()
|
|
||||||
skills = resolved_storage.load_skills(enabled_only=False)
|
|
||||||
|
|
||||||
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
|
|
||||||
if skill is None:
|
|
||||||
return None
|
|
||||||
if not skill.enabled:
|
|
||||||
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
|
|
||||||
if available_skills is not None and reference.name not in available_skills:
|
|
||||||
return _SlashSkillCommandResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
|
|
||||||
|
|
||||||
return _SlashSkillCommandResolution(route_to_chat=True)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.exception("[Manager] failed to resolve slash skill command")
|
|
||||||
raise SlashSkillCommandResolutionError("Failed to resolve slash skill command. Please check the skill configuration.") from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedAttachment]:
|
||||||
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
"""Resolve virtual artifact paths to host filesystem paths with metadata.
|
||||||
|
|
||||||
@@ -613,14 +499,8 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
write_upload_file_no_symlink,
|
write_upload_file_no_symlink,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prepare_uploads_dir() -> tuple[Path, set[str]]:
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
# Worker thread: ensure_uploads_dir's mkdir and the iterdir enumeration are
|
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
|
||||||
# blocking filesystem IO that must stay off the event loop.
|
|
||||||
target = ensure_uploads_dir(thread_id)
|
|
||||||
existing = {entry.name for entry in target.iterdir() if entry.is_file()}
|
|
||||||
return target, existing
|
|
||||||
|
|
||||||
uploads_dir, seen_names = await asyncio.to_thread(_prepare_uploads_dir)
|
|
||||||
|
|
||||||
created: list[dict[str, Any]] = []
|
created: list[dict[str, Any]] = []
|
||||||
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
|
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
|
||||||
@@ -668,7 +548,7 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
|
|||||||
|
|
||||||
dest = uploads_dir / safe_name
|
dest = uploads_dir / safe_name
|
||||||
try:
|
try:
|
||||||
dest = await asyncio.to_thread(write_upload_file_no_symlink, uploads_dir, safe_name, data)
|
dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
|
||||||
except UnsafeUploadPathError:
|
except UnsafeUploadPathError:
|
||||||
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
|
||||||
continue
|
continue
|
||||||
@@ -734,7 +614,6 @@ class ChannelManager:
|
|||||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||||
default_session: dict[str, Any] | None = None,
|
default_session: dict[str, Any] | None = None,
|
||||||
channel_sessions: dict[str, Any] | None = None,
|
channel_sessions: dict[str, Any] | None = None,
|
||||||
connection_repo: Any | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.store = store
|
self.store = store
|
||||||
@@ -744,10 +623,7 @@ class ChannelManager:
|
|||||||
self._assistant_id = assistant_id
|
self._assistant_id = assistant_id
|
||||||
self._default_session = _as_dict(default_session)
|
self._default_session = _as_dict(default_session)
|
||||||
self._channel_sessions = dict(channel_sessions or {})
|
self._channel_sessions = dict(channel_sessions or {})
|
||||||
self._connection_repo = connection_repo
|
|
||||||
self._client = None # lazy init — langgraph_sdk async client
|
self._client = None # lazy init — langgraph_sdk async client
|
||||||
self._channel_metadata_synced: set[str] = set()
|
|
||||||
self._skill_storage: SkillStorage | None = None
|
|
||||||
self._csrf_token = generate_csrf_token()
|
self._csrf_token = generate_csrf_token()
|
||||||
self._semaphore: asyncio.Semaphore | None = None
|
self._semaphore: asyncio.Semaphore | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -795,17 +671,12 @@ class ChannelManager:
|
|||||||
configurable["checkpoint_ns"] = ""
|
configurable["checkpoint_ns"] = ""
|
||||||
configurable["thread_id"] = thread_id
|
configurable["thread_id"] = thread_id
|
||||||
|
|
||||||
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
||||||
# For browser-connected IM channels, prefer the DeerFlow account that
|
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
||||||
# owns the connection. Preserve the raw platform user under
|
# under ``channel_user_id`` for platform-facing lookups.
|
||||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
|
||||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||||
owner_user_id = _effective_owner_user_id(msg)
|
|
||||||
if owner_user_id:
|
|
||||||
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
|
|
||||||
elif msg.user_id:
|
|
||||||
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
|
|
||||||
if msg.user_id:
|
if msg.user_id:
|
||||||
|
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||||
run_context_identity["channel_user_id"] = msg.user_id
|
run_context_identity["channel_user_id"] = msg.user_id
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
@@ -825,21 +696,6 @@ class ChannelManager:
|
|||||||
|
|
||||||
return assistant_id, run_config, run_context
|
return assistant_id, run_config, run_context
|
||||||
|
|
||||||
def _resolve_available_skill_names(self, msg: InboundMessage) -> set[str] | None:
|
|
||||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or ""
|
|
||||||
_, _, run_context = self._resolve_run_params(msg, thread_id)
|
|
||||||
if run_context.get("is_bootstrap"):
|
|
||||||
return {"bootstrap"}
|
|
||||||
|
|
||||||
agent_name = run_context.get("agent_name")
|
|
||||||
if not isinstance(agent_name, str) or not agent_name.strip():
|
|
||||||
return None
|
|
||||||
|
|
||||||
agent_config = load_agent_config(_normalize_custom_agent_name(agent_name))
|
|
||||||
if agent_config and agent_config.skills is not None:
|
|
||||||
return set(agent_config.skills)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# -- LangGraph SDK client (lazy) ----------------------------------------
|
# -- LangGraph SDK client (lazy) ----------------------------------------
|
||||||
|
|
||||||
def _get_client(self):
|
def _get_client(self):
|
||||||
@@ -857,11 +713,6 @@ class ChannelManager:
|
|||||||
)
|
)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
def _get_skill_storage(self) -> SkillStorage:
|
|
||||||
if self._skill_storage is None:
|
|
||||||
self._skill_storage = get_or_new_skill_storage()
|
|
||||||
return self._skill_storage
|
|
||||||
|
|
||||||
# -- lifecycle ---------------------------------------------------------
|
# -- lifecycle ---------------------------------------------------------
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -917,7 +768,6 @@ class ChannelManager:
|
|||||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||||
|
|
||||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||||
msg = _apply_effective_owner(msg)
|
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
if msg.msg_type == InboundMessageType.COMMAND:
|
if msg.msg_type == InboundMessageType.COMMAND:
|
||||||
@@ -932,14 +782,6 @@ class ChannelManager:
|
|||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
await self._send_error(msg, str(exc))
|
await self._send_error(msg, str(exc))
|
||||||
except SlashSkillCommandResolutionError as exc:
|
|
||||||
logger.warning(
|
|
||||||
"Slash skill command resolution failed for %s (chat=%s): %s",
|
|
||||||
msg.channel_name,
|
|
||||||
msg.chat_id,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await self._send_error(msg, str(exc))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Error handling message from %s (chat=%s)",
|
"Error handling message from %s (chat=%s)",
|
||||||
@@ -950,27 +792,10 @@ class ChannelManager:
|
|||||||
|
|
||||||
# -- chat handling -----------------------------------------------------
|
# -- chat handling -----------------------------------------------------
|
||||||
|
|
||||||
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||||
if msg.connection_id and self._connection_repo is not None:
|
"""Create a new thread through Gateway and store the mapping."""
|
||||||
return await self._connection_repo.get_thread_id(
|
thread = await client.threads.create()
|
||||||
msg.connection_id,
|
thread_id = thread["thread_id"]
|
||||||
msg.chat_id,
|
|
||||||
msg.topic_id,
|
|
||||||
)
|
|
||||||
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
|
||||||
|
|
||||||
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
|
||||||
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
|
||||||
await self._connection_repo.set_thread_id(
|
|
||||||
connection_id=msg.connection_id,
|
|
||||||
owner_user_id=msg.owner_user_id,
|
|
||||||
provider=msg.channel_name,
|
|
||||||
external_conversation_id=msg.chat_id,
|
|
||||||
external_topic_id=msg.topic_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
self.store.set_thread_id(
|
self.store.set_thread_id(
|
||||||
msg.channel_name,
|
msg.channel_name,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
@@ -978,49 +803,18 @@ class ChannelManager:
|
|||||||
topic_id=msg.topic_id,
|
topic_id=msg.topic_id,
|
||||||
user_id=msg.user_id,
|
user_id=msg.user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
|
||||||
"""Create a new thread through Gateway and store the mapping."""
|
|
||||||
metadata = _thread_channel_metadata(msg)
|
|
||||||
owner_headers = _owner_headers(msg)
|
|
||||||
if owner_headers:
|
|
||||||
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
|
|
||||||
else:
|
|
||||||
thread = await client.threads.create(metadata=metadata)
|
|
||||||
thread_id = thread["thread_id"]
|
|
||||||
await self._store_thread_id(msg, thread_id)
|
|
||||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
|
|
||||||
"""Best-effort source metadata backfill for existing IM-created threads."""
|
|
||||||
# The metadata (provider/chat/topic) is constant for a thread, so one
|
|
||||||
# successful backfill per manager lifetime is enough — skip the
|
|
||||||
# redundant PATCH on every subsequent inbound message.
|
|
||||||
if thread_id in self._channel_metadata_synced:
|
|
||||||
return
|
|
||||||
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
|
|
||||||
if owner_headers := _owner_headers(msg):
|
|
||||||
update_kwargs["headers"] = owner_headers
|
|
||||||
try:
|
|
||||||
await client.threads.update(thread_id, **update_kwargs)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
|
|
||||||
return
|
|
||||||
if len(self._channel_metadata_synced) > 4096:
|
|
||||||
self._channel_metadata_synced.clear()
|
|
||||||
self._channel_metadata_synced.add(thread_id)
|
|
||||||
|
|
||||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
||||||
# Look up existing DeerFlow thread.
|
# Look up existing DeerFlow thread.
|
||||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||||
thread_id = await self._lookup_thread_id(msg)
|
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||||
if thread_id:
|
if thread_id:
|
||||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||||
await self._update_thread_channel_metadata(client, msg, thread_id)
|
|
||||||
|
|
||||||
# No existing thread found — create a new one
|
# No existing thread found — create a new one
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
@@ -1042,11 +836,9 @@ class ChannelManager:
|
|||||||
if extra_context:
|
if extra_context:
|
||||||
run_context.update(extra_context)
|
run_context.update(extra_context)
|
||||||
|
|
||||||
original_text = msg.text
|
|
||||||
uploaded = await _ingest_inbound_files(thread_id, msg)
|
uploaded = await _ingest_inbound_files(thread_id, msg)
|
||||||
if uploaded:
|
if uploaded:
|
||||||
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
|
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
|
||||||
human_message = _human_input_message(msg.text, original_content=original_text)
|
|
||||||
|
|
||||||
if self._channel_supports_streaming(msg.channel_name):
|
if self._channel_supports_streaming(msg.channel_name):
|
||||||
await self._handle_streaming_chat(
|
await self._handle_streaming_chat(
|
||||||
@@ -1056,24 +848,18 @@ class ChannelManager:
|
|||||||
assistant_id,
|
assistant_id,
|
||||||
run_config,
|
run_config,
|
||||||
run_context,
|
run_context,
|
||||||
human_message,
|
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
run_kwargs: dict[str, Any] = {
|
|
||||||
"input": {"messages": [human_message]},
|
|
||||||
"config": run_config,
|
|
||||||
"context": run_context,
|
|
||||||
"multitask_strategy": "reject",
|
|
||||||
}
|
|
||||||
if owner_headers := _owner_headers(msg):
|
|
||||||
run_kwargs["headers"] = owner_headers
|
|
||||||
try:
|
try:
|
||||||
result = await client.runs.wait(
|
result = await client.runs.wait(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
**run_kwargs,
|
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||||
|
config=run_config,
|
||||||
|
context=run_context,
|
||||||
|
multitask_strategy="reject",
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if _is_thread_busy_error(exc):
|
if _is_thread_busy_error(exc):
|
||||||
@@ -1110,8 +896,6 @@ class ChannelManager:
|
|||||||
artifacts=artifacts,
|
artifacts=artifacts,
|
||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
connection_id=msg.connection_id,
|
|
||||||
owner_user_id=msg.owner_user_id,
|
|
||||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||||
)
|
)
|
||||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||||
@@ -1125,7 +909,6 @@ class ChannelManager:
|
|||||||
assistant_id: str,
|
assistant_id: str,
|
||||||
run_config: dict[str, Any],
|
run_config: dict[str, Any],
|
||||||
run_context: dict[str, Any],
|
run_context: dict[str, Any],
|
||||||
human_message: dict[str, Any],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
|
|
||||||
@@ -1136,26 +919,21 @@ class ChannelManager:
|
|||||||
last_published_text = ""
|
last_published_text = ""
|
||||||
last_publish_at = 0.0
|
last_publish_at = 0.0
|
||||||
stream_error: BaseException | None = None
|
stream_error: BaseException | None = None
|
||||||
stream_kwargs: dict[str, Any] = {
|
|
||||||
"input": {"messages": [human_message]},
|
|
||||||
"config": run_config,
|
|
||||||
"context": run_context,
|
|
||||||
"stream_mode": list(STREAM_MODES),
|
|
||||||
"multitask_strategy": "reject",
|
|
||||||
}
|
|
||||||
if owner_headers := _owner_headers(msg):
|
|
||||||
stream_kwargs["headers"] = owner_headers
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in client.runs.stream(
|
async for chunk in client.runs.stream(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
**stream_kwargs,
|
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||||
|
config=run_config,
|
||||||
|
context=run_context,
|
||||||
|
stream_mode=["messages-tuple", "values"],
|
||||||
|
multitask_strategy="reject",
|
||||||
):
|
):
|
||||||
event = getattr(chunk, "event", "")
|
event = getattr(chunk, "event", "")
|
||||||
data = getattr(chunk, "data", None)
|
data = getattr(chunk, "data", None)
|
||||||
|
|
||||||
if event in MESSAGE_STREAM_EVENTS:
|
if event == "messages-tuple":
|
||||||
accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data)
|
accumulated_text, current_message_id = _accumulate_stream_text(streamed_buffers, current_message_id, data)
|
||||||
if accumulated_text:
|
if accumulated_text:
|
||||||
latest_text = accumulated_text
|
latest_text = accumulated_text
|
||||||
@@ -1180,8 +958,6 @@ class ChannelManager:
|
|||||||
text=latest_text,
|
text=latest_text,
|
||||||
is_final=False,
|
is_final=False,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
connection_id=msg.connection_id,
|
|
||||||
owner_user_id=msg.owner_user_id,
|
|
||||||
metadata=_response_metadata(msg.metadata),
|
metadata=_response_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1228,8 +1004,6 @@ class ChannelManager:
|
|||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
is_final=True,
|
is_final=True,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
connection_id=msg.connection_id,
|
|
||||||
owner_user_id=msg.owner_user_id,
|
|
||||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1237,20 +1011,11 @@ class ChannelManager:
|
|||||||
# -- command handling --------------------------------------------------
|
# -- command handling --------------------------------------------------
|
||||||
|
|
||||||
async def _handle_command(self, msg: InboundMessage) -> None:
|
async def _handle_command(self, msg: InboundMessage) -> None:
|
||||||
raw_text = msg.text
|
text = msg.text.strip()
|
||||||
text = raw_text.strip()
|
|
||||||
parts = text.split(maxsplit=1)
|
parts = text.split(maxsplit=1)
|
||||||
reply: str | None = None
|
command = parts[0].lower().lstrip("/")
|
||||||
if not parts:
|
|
||||||
command = None
|
|
||||||
reply = _unknown_command_reply()
|
|
||||||
else:
|
|
||||||
command = parts[0].lower().removeprefix("/")
|
|
||||||
|
|
||||||
if reply is None and not raw_text.startswith("/"):
|
if command == "bootstrap":
|
||||||
reply = _unknown_command_reply(command)
|
|
||||||
|
|
||||||
if reply is None and command == "bootstrap":
|
|
||||||
from dataclasses import replace as _dc_replace
|
from dataclasses import replace as _dc_replace
|
||||||
|
|
||||||
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
|
chat_text = parts[1] if len(parts) > 1 else "Initialize workspace"
|
||||||
@@ -1258,19 +1023,27 @@ class ChannelManager:
|
|||||||
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
|
await self._handle_chat(chat_msg, extra_context={"is_bootstrap": True})
|
||||||
return
|
return
|
||||||
|
|
||||||
if reply is None and command == "new":
|
if command == "new":
|
||||||
# Create a new thread through Gateway
|
# Create a new thread through Gateway
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
await self._create_thread(client, msg)
|
thread = await client.threads.create()
|
||||||
|
new_thread_id = thread["thread_id"]
|
||||||
|
self.store.set_thread_id(
|
||||||
|
msg.channel_name,
|
||||||
|
msg.chat_id,
|
||||||
|
new_thread_id,
|
||||||
|
topic_id=msg.topic_id,
|
||||||
|
user_id=msg.user_id,
|
||||||
|
)
|
||||||
reply = "New conversation started."
|
reply = "New conversation started."
|
||||||
elif reply is None and command == "status":
|
elif command == "status":
|
||||||
thread_id = await self._lookup_thread_id(msg)
|
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||||
elif reply is None and command == "models":
|
elif command == "models":
|
||||||
reply = await self._fetch_gateway("/api/models", "models")
|
reply = await self._fetch_gateway("/api/models", "models")
|
||||||
elif reply is None and command == "memory":
|
elif command == "memory":
|
||||||
reply = await self._fetch_gateway("/api/memory", "memory")
|
reply = await self._fetch_gateway("/api/memory", "memory")
|
||||||
elif reply is None and command == "help":
|
elif command == "help":
|
||||||
reply = (
|
reply = (
|
||||||
"Available commands:\n"
|
"Available commands:\n"
|
||||||
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
|
"/bootstrap — Start a bootstrap session (enables agent setup)\n"
|
||||||
@@ -1278,36 +1051,18 @@ class ChannelManager:
|
|||||||
"/status — Show current thread info\n"
|
"/status — Show current thread info\n"
|
||||||
"/models — List available models\n"
|
"/models — List available models\n"
|
||||||
"/memory — Show memory status\n"
|
"/memory — Show memory status\n"
|
||||||
"/<skill-name> <task> — Activate an enabled skill for one turn\n"
|
|
||||||
"/help — Show this help"
|
"/help — Show this help"
|
||||||
)
|
)
|
||||||
elif reply is None:
|
|
||||||
slash_resolution = await asyncio.to_thread(
|
|
||||||
lambda: _resolve_slash_skill_command(
|
|
||||||
raw_text,
|
|
||||||
self._resolve_available_skill_names(msg),
|
|
||||||
self._get_skill_storage,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if slash_resolution and slash_resolution.failure_message:
|
|
||||||
reply = slash_resolution.failure_message
|
|
||||||
elif slash_resolution and slash_resolution.route_to_chat:
|
|
||||||
from dataclasses import replace as _dc_replace
|
|
||||||
|
|
||||||
chat_msg = _dc_replace(msg, msg_type=InboundMessageType.CHAT)
|
|
||||||
await self._handle_chat(chat_msg)
|
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
reply = _unknown_command_reply(command)
|
available = " | ".join(sorted(KNOWN_CHANNEL_COMMANDS))
|
||||||
|
reply = f"Unknown command: /{command}. Available commands: {available}"
|
||||||
|
|
||||||
outbound = OutboundMessage(
|
outbound = OutboundMessage(
|
||||||
channel_name=msg.channel_name,
|
channel_name=msg.channel_name,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
thread_id=await self._lookup_thread_id(msg) or "",
|
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||||
text=reply,
|
text=reply,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
connection_id=msg.connection_id,
|
|
||||||
owner_user_id=msg.owner_user_id,
|
|
||||||
metadata=_slim_metadata(msg.metadata),
|
metadata=_slim_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
@@ -1343,11 +1098,9 @@ class ChannelManager:
|
|||||||
outbound = OutboundMessage(
|
outbound = OutboundMessage(
|
||||||
channel_name=msg.channel_name,
|
channel_name=msg.channel_name,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
thread_id=await self._lookup_thread_id(msg) or "",
|
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||||
text=error_text,
|
text=error_text,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
connection_id=msg.connection_id,
|
|
||||||
owner_user_id=msg.owner_user_id,
|
|
||||||
metadata=_slim_metadata(msg.metadata),
|
metadata=_slim_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
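The `_handle_command` hunk above pairs two ways of stripping the leading slash: `parts[0].lower().removeprefix("/")` behind an `if not parts` guard on one side, and an unguarded `parts[0].lower().lstrip("/")` on the other. A minimal sketch of how the two behave differently, assuming a standalone helper named `parse_command` (hypothetical, not part of the diff) and Python 3.9+ for `str.removeprefix`:

```python
from __future__ import annotations


def parse_command(text: str, *, strict: bool = True) -> str | None:
    """Return the lowercased command name, or None when the text is empty.

    Hypothetical helper for illustration only; the diff inlines this logic
    inside ChannelManager._handle_command.
    """
    parts = text.strip().split(maxsplit=1)
    if not parts:
        # "".split() returns [], so indexing parts[0] without this guard
        # would raise IndexError on empty or whitespace-only input.
        return None
    head = parts[0].lower()
    # removeprefix drops exactly one leading "/"; lstrip drops all of them.
    return head.removeprefix("/") if strict else head.lstrip("/")


if __name__ == "__main__":
    for sample in ("/help", "//help", "   "):
        print(repr(sample), parse_command(sample), parse_command(sample, strict=False))
```

With `lstrip`, input such as `//help` is treated as `help`, while `removeprefix` keeps the extra slash so the command falls through to the unknown-command branch. Whether empty text can actually reach the handler depends on the callers, which this hunk does not show.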
|
|||||||
@@ -44,12 +44,6 @@ class InboundMessage:
|
|||||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||||
reuse the same DeerFlow thread. When ``None``, each message
|
reuse the same DeerFlow thread. When ``None``, each message
|
||||||
creates a new thread (one-shot Q&A).
|
creates a new thread (one-shot Q&A).
|
||||||
connection_id: Optional DeerFlow channel connection id. When present,
|
|
||||||
conversation mapping is scoped by the connection instead of the
|
|
||||||
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
|
||||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
|
||||||
Platform user ids stay in ``user_id``.
|
|
||||||
workspace_id: Optional external workspace/guild/team id.
|
|
||||||
files: Optional list of file attachments (platform-specific dicts).
|
files: Optional list of file attachments (platform-specific dicts).
|
||||||
metadata: Arbitrary extra data from the channel.
|
metadata: Arbitrary extra data from the channel.
|
||||||
created_at: Unix timestamp when the message was created.
|
created_at: Unix timestamp when the message was created.
|
||||||
@@ -62,9 +56,6 @@ class InboundMessage:
|
|||||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||||
thread_ts: str | None = None
|
thread_ts: str | None = None
|
||||||
topic_id: str | None = None
|
topic_id: str | None = None
|
||||||
connection_id: str | None = None
|
|
||||||
owner_user_id: str | None = None
|
|
||||||
workspace_id: str | None = None
|
|
||||||
files: list[dict[str, Any]] = field(default_factory=list)
|
files: list[dict[str, Any]] = field(default_factory=list)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
@@ -104,9 +95,6 @@ class OutboundMessage:
|
|||||||
is_final: Whether this is the final message in the response stream.
|
is_final: Whether this is the final message in the response stream.
|
||||||
thread_ts: Optional platform thread identifier for threaded replies.
|
thread_ts: Optional platform thread identifier for threaded replies.
|
||||||
metadata: Arbitrary extra data.
|
metadata: Arbitrary extra data.
|
||||||
connection_id: Optional DeerFlow channel connection id used for
|
|
||||||
connection-specific outbound credentials.
|
|
||||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
|
||||||
created_at: Unix timestamp.
|
created_at: Unix timestamp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -118,8 +106,6 @@ class OutboundMessage:
|
|||||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||||
is_final: bool = True
|
is_final: bool = True
|
||||||
thread_ts: str | None = None
|
thread_ts: str | None = None
|
||||||
connection_id: str | None = None
|
|
||||||
owner_user_id: str | None = None
|
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|||||||
@@ -1,154 +0,0 @@
|
|||||||
"""Local persistence for runtime IM channel configuration."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import tempfile
|
|
||||||
import threading
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelRuntimeConfigStore:
|
|
||||||
"""JSON-backed store for channel credentials entered from the UI.
|
|
||||||
|
|
||||||
This intentionally mirrors ``ChannelStore``: local/private deployments get
|
|
||||||
durable runtime configuration without needing a public callback URL or a
|
|
||||||
config.yaml edit.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, path: str | Path | None = None) -> None:
|
|
||||||
if path is None:
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
|
|
||||||
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
|
|
||||||
self._path = Path(path)
|
|
||||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._data: dict[str, dict[str, Any]] = self._load()
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
def _load(self) -> dict[str, dict[str, Any]]:
|
|
||||||
if self._path.exists():
|
|
||||||
try:
|
|
||||||
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
|
||||||
except (json.JSONDecodeError, OSError):
|
|
||||||
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
|
|
||||||
return {}
|
|
||||||
if isinstance(raw, dict):
|
|
||||||
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _save(self) -> None:
|
|
||||||
fd = tempfile.NamedTemporaryFile(
|
|
||||||
mode="w",
|
|
||||||
dir=self._path.parent,
|
|
||||||
suffix=".tmp",
|
|
||||||
delete=False,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
json.dump(self._data, fd, indent=2, ensure_ascii=False)
|
|
||||||
fd.close()
|
|
||||||
Path(fd.name).replace(self._path)
|
|
||||||
try:
|
|
||||||
self._path.chmod(0o600)
|
|
||||||
except OSError:
|
|
||||||
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
|
|
||||||
except BaseException:
|
|
||||||
fd.close()
|
|
||||||
Path(fd.name).unlink(missing_ok=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def load_all(self) -> dict[str, dict[str, Any]]:
|
|
||||||
with self._lock:
|
|
||||||
return {name: dict(config) for name, config in self._data.items()}
|
|
||||||
|
|
||||||
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
|
|
||||||
with self._lock:
|
|
||||||
config = self._data.get(provider)
|
|
||||||
return dict(config) if isinstance(config, dict) else None
|
|
||||||
|
|
||||||
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
|
|
||||||
with self._lock:
|
|
||||||
self._data[provider] = dict(config)
|
|
||||||
self._save()
|
|
||||||
|
|
||||||
def set_provider_disconnected(self, provider: str) -> None:
|
|
||||||
with self._lock:
|
|
||||||
self._data[provider] = {
|
|
||||||
"enabled": False,
|
|
||||||
RUNTIME_CHANNEL_DISABLED_FLAG: True,
|
|
||||||
}
|
|
||||||
self._save()
|
|
||||||
|
|
||||||
def remove_provider_config(self, provider: str) -> bool:
|
|
||||||
with self._lock:
|
|
||||||
if provider not in self._data:
|
|
||||||
return False
|
|
||||||
del self._data[provider]
|
|
||||||
self._save()
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
|
||||||
provider_config = getattr(channel_connections_config, provider, None)
|
|
||||||
return bool(getattr(provider_config, "enabled", False))
|
|
||||||
|
|
||||||
|
|
||||||
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
|
|
||||||
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
|
|
||||||
|
|
||||||
|
|
||||||
def merge_runtime_channel_configs(
|
|
||||||
channels_config: dict[str, Any],
|
|
||||||
channel_connections_config: Any,
|
|
||||||
*,
|
|
||||||
store: ChannelRuntimeConfigStore | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
|
|
||||||
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
runtime_store = store or ChannelRuntimeConfigStore()
|
|
||||||
for provider, runtime_config in runtime_store.load_all().items():
|
|
||||||
if not _provider_enabled(channel_connections_config, provider):
|
|
||||||
continue
|
|
||||||
if _runtime_channel_disconnected(runtime_config):
|
|
||||||
channels_config.pop(provider, None)
|
|
||||||
continue
|
|
||||||
existing = channels_config.get(provider)
|
|
||||||
merged = dict(runtime_config)
|
|
||||||
if isinstance(existing, dict):
|
|
||||||
merged.update(existing)
|
|
||||||
channels_config[provider] = merged
|
|
||||||
|
|
||||||
|
|
||||||
def apply_runtime_connection_config(
|
|
||||||
channel_connections_config: Any,
|
|
||||||
*,
|
|
||||||
store: ChannelRuntimeConfigStore | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Apply persisted connection metadata that lives outside ``channels``.
|
|
||||||
|
|
||||||
Telegram uses a bot username for deep links; UI-entered values are stored
|
|
||||||
with the runtime channel config so local restarts keep the provider
|
|
||||||
configured.
|
|
||||||
"""
|
|
||||||
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
|
||||||
return channel_connections_config
|
|
||||||
|
|
||||||
runtime_store = store or ChannelRuntimeConfigStore()
|
|
||||||
telegram_runtime_config = runtime_store.get_provider_config("telegram")
|
|
||||||
bot_username = ""
|
|
||||||
if isinstance(telegram_runtime_config, dict):
|
|
||||||
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
|
|
||||||
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
|
|
||||||
return channel_connections_config
|
|
||||||
|
|
||||||
config = channel_connections_config.model_copy(deep=True)
|
|
||||||
config.telegram.bot_username = bot_username
|
|
||||||
return config
|
|
||||||
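The deleted `runtime_config_store` module relies on two small patterns: `_save` writes JSON to a temporary file in the target directory and then renames it over the destination, and `merge_runtime_channel_configs` starts from the persisted runtime values and lets the existing `channels_config` entry override them. A minimal sketch of both, assuming the free-standing names `atomic_write_json` and `merge_provider` (hypothetical, chosen here for illustration) and an explicit `encoding="utf-8"` that the original call does not pass:

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path
from typing import Any


def atomic_write_json(path: Path, data: dict[str, Any]) -> None:
    """Write data so readers see either the old file or the new one, never a partial write."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # The temp file lives in the same directory, so the rename stays on one filesystem.
    tmp = tempfile.NamedTemporaryFile(
        mode="w", dir=path.parent, suffix=".tmp", delete=False, encoding="utf-8"
    )
    try:
        json.dump(data, tmp, indent=2, ensure_ascii=False)
        tmp.close()
        Path(tmp.name).replace(path)  # rename over the destination
    except BaseException:
        # Clean up the temp file on any failure, then re-raise.
        tmp.close()
        Path(tmp.name).unlink(missing_ok=True)
        raise


def merge_provider(runtime: dict[str, Any], static: dict[str, Any] | None) -> dict[str, Any]:
    """Runtime values act as defaults; keys present in the static config win."""
    merged = dict(runtime)
    if isinstance(static, dict):
        merged.update(static)
    return merged
```

Under this ordering, a token entered from the UI only takes effect when the static config leaves that key unset, which matches the `merged.update(existing)` call in the diff. The `chmod(0o600)` step and the `threading.Lock` from the original class are omitted here for brevity.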
+24
-169
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@@ -10,7 +9,6 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||||
from app.channels.message_bus import MessageBus
|
from app.channels.message_bus import MessageBus
|
||||||
from app.channels.runtime_config_store import merge_runtime_channel_configs
|
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -44,11 +42,6 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
|||||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||||
|
|
||||||
|
|
||||||
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
|
|
||||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
|
||||||
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
||||||
value = config.pop(config_key, None)
|
value = config.pop(config_key, None)
|
||||||
if isinstance(value, str) and value.strip():
|
if isinstance(value, str) and value.strip():
|
||||||
@@ -59,30 +52,6 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
|
||||||
connection_config = getattr(app_config, "channel_connections", None)
|
|
||||||
merge_runtime_channel_configs(channels_config, connection_config)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_connection_repo(app_config: AppConfig):
|
|
||||||
connection_config = getattr(app_config, "channel_connections", None)
|
|
||||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
|
||||||
from deerflow.persistence.engine import get_session_factory
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to import channel connection repository")
|
|
||||||
return None
|
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
|
||||||
if session_factory is None:
|
|
||||||
logger.warning("Channel connections are enabled but database persistence is not available")
|
|
||||||
return None
|
|
||||||
return ChannelConnectionRepository(session_factory)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelService:
|
class ChannelService:
|
||||||
"""Manages the lifecycle of all configured IM channels.
|
"""Manages the lifecycle of all configured IM channels.
|
||||||
|
|
||||||
@@ -90,10 +59,9 @@ class ChannelService:
|
|||||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
||||||
self.bus = MessageBus()
|
self.bus = MessageBus()
|
||||||
self.store = ChannelStore()
|
self.store = ChannelStore()
|
||||||
self._connection_repo = connection_repo
|
|
||||||
config = dict(channels_config or {})
|
config = dict(channels_config or {})
|
||||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||||
@@ -106,12 +74,10 @@ class ChannelService:
|
|||||||
gateway_url=gateway_url,
|
gateway_url=gateway_url,
|
||||||
default_session=default_session if isinstance(default_session, dict) else None,
|
default_session=default_session if isinstance(default_session, dict) else None,
|
||||||
channel_sessions=channel_sessions,
|
channel_sessions=channel_sessions,
|
||||||
connection_repo=connection_repo,
|
|
||||||
)
|
)
|
||||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||||
self._config = config
|
self._config = config
|
||||||
self._running = False
|
self._running = False
|
||||||
self._readiness_locks: dict[str, asyncio.Lock] = {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
||||||
@@ -124,9 +90,8 @@ class ChannelService:
|
|||||||
# extra fields are allowed by AppConfig (extra="allow")
|
# extra fields are allowed by AppConfig (extra="allow")
|
||||||
extra = app_config.model_extra or {}
|
extra = app_config.model_extra or {}
|
||||||
if "channels" in extra:
|
if "channels" in extra:
|
||||||
channels_config = dict(extra["channels"] or {})
|
channels_config = extra["channels"]
|
||||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
return cls(channels_config=channels_config)
|
||||||
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the manager and all enabled channels."""
|
"""Start the manager and all enabled channels."""
|
||||||
@@ -134,169 +99,63 @@ class ChannelService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
await self.manager.start()
|
await self.manager.start()
|
||||||
self._running = True
|
|
||||||
|
|
||||||
ready_status = await self.ensure_ready_channels(attempts=2)
|
|
||||||
ready_count = sum(1 for ready in ready_status.values() if ready)
|
|
||||||
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
|
|
||||||
|
|
||||||
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
|
|
||||||
"""Start or restart enabled configured channels that are not ready."""
|
|
||||||
ready_status: dict[str, bool] = {}
|
|
||||||
for name, channel_config in self._config.items():
|
for name, channel_config in self._config.items():
|
||||||
if not isinstance(channel_config, dict):
|
if not isinstance(channel_config, dict):
|
||||||
continue
|
continue
|
||||||
if not channel_config.get("enabled", False):
|
if not channel_config.get("enabled", False):
|
||||||
if _channel_has_credentials(name, channel_config):
|
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||||
|
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
||||||
|
if has_creds:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
|
||||||
|
name,
|
||||||
|
name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("A configured channel is disabled, skipping")
|
logger.info("Channel %s is disabled, skipping", name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
|
await self._start_channel(name, channel_config)
|
||||||
return ready_status
|
|
||||||
|
|
||||||
async def ensure_channel_ready(
|
self._running = True
|
||||||
self,
|
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
|
||||||
name: str,
|
|
||||||
config: dict[str, Any] | None = None,
|
|
||||||
*,
|
|
||||||
attempts: int = 1,
|
|
||||||
) -> bool:
|
|
||||||
"""Ensure a single enabled channel is running using its current config."""
|
|
||||||
if not self._running:
|
|
||||||
logger.warning("ChannelService is not running; cannot ensure channel readiness")
|
|
||||||
return False
|
|
||||||
|
|
||||||
if config is not None:
|
|
||||||
self._config[name] = dict(config)
|
|
||||||
|
|
||||||
# Serialize per channel: readiness is polled from request handlers, so
|
|
||||||
# concurrent calls must not stop/start the same channel worker twice.
|
|
||||||
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
|
|
||||||
async with lock:
|
|
||||||
channel_config = self._config.get(name)
|
|
||||||
if not channel_config or not isinstance(channel_config, dict):
|
|
||||||
logger.warning("No config for requested channel")
|
|
||||||
return False
|
|
||||||
if not channel_config.get("enabled", False):
|
|
||||||
return False
|
|
||||||
|
|
||||||
channel = self._channels.get(name)
|
|
||||||
if channel is not None and channel.is_running:
|
|
||||||
return True
|
|
||||||
|
|
||||||
if channel is not None:
|
|
||||||
try:
|
|
||||||
await channel.stop()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error stopping non-running channel before readiness retry")
|
|
||||||
self._channels.pop(name, None)
|
|
||||||
|
|
||||||
max_attempts = max(1, attempts)
|
|
||||||
for attempt in range(max_attempts):
|
|
||||||
if attempt > 0:
|
|
||||||
logger.info("Retrying channel startup after readiness check")
|
|
||||||
if await self._start_channel(name, channel_config):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop all channels and the manager."""
|
"""Stop all channels and the manager."""
|
||||||
for name, channel in list(self._channels.items()):
|
for name, channel in list(self._channels.items()):
|
||||||
try:
|
try:
|
||||||
await channel.stop()
|
await channel.stop()
|
||||||
logger.info("Channel stopped")
|
logger.info("Channel %s stopped", name)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error stopping channel")
|
logger.exception("Error stopping channel %s", name)
|
||||||
self._channels.clear()
|
self._channels.clear()
|
||||||
|
|
||||||
await self.manager.stop()
|
await self.manager.stop()
|
||||||
self._running = False
|
self._running = False
|
||||||
logger.info("ChannelService stopped")
|
logger.info("ChannelService stopped")
|
||||||
|
|
||||||
def _load_channel_config(self, name: str) -> dict[str, Any] | None:
|
async def restart_channel(self, name: str) -> bool:
|
||||||
"""Load the latest config for a specific channel from disk.
|
|
||||||
|
|
||||||
Uses ``get_app_config()`` which detects file changes via mtime,
|
|
||||||
so edits to ``config.yaml`` are picked up without a process restart.
|
|
||||||
The UI runtime-config overlay applied at startup is re-applied here
|
|
||||||
so a file-driven reload neither drops credentials entered from the
|
|
||||||
browser nor resurrects a channel disconnected from it.
|
|
||||||
Falls back to the cached ``self._config`` when config loading fails.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
|
|
||||||
app_config = get_app_config()
|
|
||||||
extra = app_config.model_extra or {}
|
|
||||||
channels_config = dict(extra.get("channels") or {})
|
|
||||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
|
||||||
channel_config = channels_config.get(name)
|
|
||||||
if isinstance(channel_config, dict):
|
|
||||||
# Update the cached config so get_status() stays consistent.
|
|
||||||
self._config[name] = channel_config
|
|
||||||
return channel_config
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to reload config for channel %s, using cached version", name)
|
|
||||||
return self._config.get(name)
|
|
||||||
|
|
||||||
async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
|
|
||||||
"""Restart a specific channel. Returns True if successful."""
|
"""Restart a specific channel. Returns True if successful."""
|
||||||
if name in self._channels:
|
if name in self._channels:
|
||||||
try:
|
try:
|
||||||
await self._channels[name].stop()
|
await self._channels[name].stop()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error stopping channel for restart")
|
logger.exception("Error stopping channel %s for restart", name)
|
||||||
del self._channels[name]
|
del self._channels[name]
|
||||||
|
|
||||||
if reload_config:
|
|
||||||
# Reading config.yaml and the runtime store is disk IO; keep it
|
|
||||||
# off the event loop.
|
|
||||||
config = await asyncio.to_thread(self._load_channel_config, name)
|
|
||||||
else:
|
|
||||||
config = self._config.get(name)
|
config = self._config.get(name)
|
||||||
if not config or not isinstance(config, dict):
|
if not config or not isinstance(config, dict):
|
||||||
logger.warning("No config for requested channel")
|
logger.warning("No config for channel %s", name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not config.get("enabled", False):
|
|
||||||
logger.info("Channel %s is disabled, skipping restart", name)
|
|
||||||
return True
|
|
||||||
|
|
||||||
return await self._start_channel(name, config)
|
return await self._start_channel(name, config)
|
||||||
|
|
||||||
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
|
|
||||||
"""Apply runtime config for a channel and restart it if the service is running."""
|
|
||||||
self._config[name] = dict(config)
|
|
||||||
if not self._running:
|
|
||||||
return True
|
|
||||||
# The caller just supplied the authoritative config (e.g. credentials
|
|
||||||
# entered in the browser that are never written to config.yaml) — a
|
|
||||||
# file reload here would clobber it with the stale on-disk entry.
|
|
||||||
return await self.restart_channel(name, reload_config=False)
|
|
||||||
|
|
||||||
async def remove_channel(self, name: str) -> bool:
|
|
||||||
"""Remove runtime config for a channel and stop it if currently running."""
|
|
||||||
self._config.pop(name, None)
|
|
||||||
channel = self._channels.pop(name, None)
|
|
||||||
if channel is None:
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
await channel.stop()
|
|
||||||
logger.info("Channel stopped and removed")
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Error stopping channel for removal")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||||
"""Instantiate and start a single channel."""
|
"""Instantiate and start a single channel."""
|
||||||
import_path = _CHANNEL_REGISTRY.get(name)
|
import_path = _CHANNEL_REGISTRY.get(name)
|
||||||
if not import_path:
|
if not import_path:
|
||||||
logger.warning("Unknown channel type")
|
logger.warning("Unknown channel type: %s", name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -304,26 +163,24 @@ class ChannelService:
|
|||||||
|
|
||||||
channel_cls = resolve_class(import_path, base_class=None)
|
channel_cls = resolve_class(import_path, base_class=None)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to import channel class")
|
logger.exception("Failed to import channel class for %s", name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
config = dict(config)
|
config = dict(config)
|
||||||
config["channel_store"] = self.store
|
config["channel_store"] = self.store
|
||||||
if self._connection_repo is not None:
|
|
||||||
config["connection_repo"] = self._connection_repo
|
|
||||||
channel = channel_cls(bus=self.bus, config=config)
|
channel = channel_cls(bus=self.bus, config=config)
|
||||||
self._channels[name] = channel
|
self._channels[name] = channel
|
||||||
await channel.start()
|
await channel.start()
|
||||||
if not channel.is_running:
|
if not channel.is_running:
|
||||||
self._channels.pop(name, None)
|
self._channels.pop(name, None)
|
||||||
logger.error("Channel did not enter a running state after start()")
|
logger.error("Channel %s did not enter a running state after start()", name)
|
||||||
return False
|
return False
|
||||||
logger.info("Channel started")
|
logger.info("Channel %s started", name)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
self._channels.pop(name, None)
|
self._channels.pop(name, None)
|
||||||
logger.exception("Failed to start channel")
|
logger.exception("Failed to start channel %s", name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_status(self) -> dict[str, Any]:
|
def get_status(self) -> dict[str, Any]:
|
||||||
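Note on the hunk above: the left-hand `ensure_channel_ready` takes one `asyncio.Lock` per channel name, so concurrent readiness checks cannot stop or start the same channel twice. The sketch below shows that pattern in isolation. `ReadinessGuard`, `_start`, and `demo` are hypothetical names used only for illustration and are not part of the diff.

```python
import asyncio


class ReadinessGuard:
    """Hypothetical stand-in for per-channel readiness locking."""

    def __init__(self) -> None:
        self._locks: dict[str, asyncio.Lock] = {}
        self._running: set[str] = set()
        self.start_calls: list[str] = []

    async def _start(self, name: str) -> bool:
        # Simulated slow startup; the await lets other callers interleave.
        self.start_calls.append(name)
        await asyncio.sleep(0.01)
        self._running.add(name)
        return True

    async def ensure_ready(self, name: str) -> bool:
        # No await sits between lookup and insert, so setdefault cannot race
        # with another coroutine on the same event loop.
        lock = self._locks.setdefault(name, asyncio.Lock())
        async with lock:
            if name in self._running:
                # Another caller finished the startup while we waited.
                return True
            return await self._start(name)


async def demo() -> list[str]:
    guard = ReadinessGuard()
    # Five concurrent readiness checks for the same channel.
    results = await asyncio.gather(*(guard.ensure_ready("slack") for _ in range(5)))
    assert all(results)
    return guard.start_calls


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The running check happens inside the lock, so callers that queued behind the first one see the channel already running and return without starting it again.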
@@ -362,9 +219,7 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
|
|||||||
global _channel_service
|
global _channel_service
|
||||||
if _channel_service is not None:
|
if _channel_service is not None:
|
||||||
return _channel_service
|
return _channel_service
|
||||||
# from_app_config reads the JSON channel store and runtime config files;
|
_channel_service = ChannelService.from_app_config(app_config)
|
||||||
# keep that disk IO off the event loop.
|
|
||||||
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
|
|
||||||
await _channel_service.start()
|
await _channel_service.start()
|
||||||
return _channel_service
|
return _channel_service
|
||||||
|
|
||||||
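Note on the hunk above: the left-hand `start_channel_service` wraps `ChannelService.from_app_config` in `asyncio.to_thread` so that file reads stay off the event loop. A minimal sketch of that idea follows, assuming a hypothetical blocking loader (`load_config_blocking`) and a hypothetical heartbeat task that shows the loop stays responsive while the load runs.

```python
import asyncio
import threading
import time


def load_config_blocking(name: str) -> dict[str, object]:
    """Hypothetical blocking loader standing in for a config file read."""
    time.sleep(0.05)  # simulates disk IO
    return {
        "name": name,
        "enabled": True,
        "on_main_thread": threading.current_thread() is threading.main_thread(),
    }


async def heartbeat(ticks: list[float], stop: asyncio.Event) -> None:
    # Records a timestamp every few milliseconds while the loop is free.
    while not stop.is_set():
        ticks.append(time.monotonic())
        await asyncio.sleep(0.005)


async def demo() -> tuple[dict[str, object], int]:
    ticks: list[float] = []
    stop = asyncio.Event()
    hb = asyncio.create_task(heartbeat(ticks, stop))
    # The blocking call runs in a worker thread; the loop keeps ticking.
    config = await asyncio.to_thread(load_config_blocking, "slack")
    stop.set()
    await hb
    return config, len(ticks)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Calling `load_config_blocking` directly inside `demo` would freeze the heartbeat for the whole sleep. That is the kind of stall the comments in the hunk are guarding against.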
|
|||||||
+14 -172
@@ -9,8 +9,6 @@ from typing import Any
|
|||||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -34,20 +32,6 @@ def _normalize_allowed_users(allowed_users: Any) -> set[str]:
|
|||||||
return {str(user_id) for user_id in values if str(user_id)}
|
return {str(user_id) for user_id in values if str(user_id)}
|
||||||
|
|
||||||
|
|
||||||
def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
|
|
||||||
if not bot_user_id:
|
|
||||||
return text
|
|
||||||
if not text.startswith("<@"):
|
|
||||||
return text
|
|
||||||
end = text.find(">")
|
|
||||||
if end <= 2:
|
|
||||||
return text
|
|
||||||
mentioned_user_id = text[2:end].split("|", 1)[0].lstrip("!")
|
|
||||||
if mentioned_user_id != bot_user_id:
|
|
||||||
return text
|
|
||||||
return text[end + 1 :].lstrip()
|
|
||||||
|
|
||||||
|
|
||||||
class SlackChannel(Channel):
|
class SlackChannel(Channel):
|
||||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||||
|
|
||||||
@@ -65,11 +49,6 @@ class SlackChannel(Channel):
|
|||||||
self._web_client = None
|
self._web_client = None
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
self._web_client_factory = config.get("web_client_factory")
|
|
||||||
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
|
|
||||||
configured_bot_user_id = config.get("bot_user_id")
|
|
||||||
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -84,28 +63,15 @@ class SlackChannel(Channel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self._SocketModeResponse = SocketModeResponse
|
self._SocketModeResponse = SocketModeResponse
|
||||||
if self._web_client_factory is None:
|
|
||||||
self._web_client_factory = WebClient
|
|
||||||
|
|
||||||
bot_token = self.config.get("bot_token", "")
|
bot_token = self.config.get("bot_token", "")
|
||||||
app_token = self.config.get("app_token", "")
|
app_token = self.config.get("app_token", "")
|
||||||
|
|
||||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
|
||||||
if not bot_token:
|
|
||||||
logger.error("Slack HTTP Events mode requires bot_token")
|
|
||||||
return
|
|
||||||
await self._initialize_operator_web_client(str(bot_token))
|
|
||||||
self._loop = asyncio.get_event_loop()
|
|
||||||
self._running = True
|
|
||||||
self.bus.subscribe_outbound(self._on_outbound)
|
|
||||||
logger.info("Slack channel started in HTTP Events mode")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not bot_token or not app_token:
|
if not bot_token or not app_token:
|
||||||
logger.error("Slack channel requires bot_token and app_token")
|
logger.error("Slack channel requires bot_token and app_token")
|
||||||
return
|
return
|
||||||
|
|
||||||
await self._initialize_operator_web_client(str(bot_token))
|
self._web_client = WebClient(token=bot_token)
|
||||||
self._socket_client = SocketModeClient(
|
self._socket_client = SocketModeClient(
|
||||||
app_token=app_token,
|
app_token=app_token,
|
||||||
web_client=self._web_client,
|
web_client=self._web_client,
|
||||||
@@ -130,8 +96,7 @@ class SlackChannel(Channel):
|
|||||||
logger.info("Slack channel stopped")
|
logger.info("Slack channel stopped")
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||||
web_client = await self._get_web_client_for_message(msg)
|
if not self._web_client:
|
||||||
if not web_client:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
@@ -144,12 +109,11 @@ class SlackChannel(Channel):
|
|||||||
last_exc: Exception | None = None
|
last_exc: Exception | None = None
|
||||||
for attempt in range(_max_retries):
|
for attempt in range(_max_retries):
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||||
# Add a completion reaction to the thread root
|
# Add a completion reaction to the thread root
|
||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self._add_reaction_with_client,
|
self._add_reaction,
|
||||||
web_client,
|
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"white_check_mark",
|
"white_check_mark",
|
||||||
@@ -173,8 +137,7 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self._add_reaction_with_client,
|
self._add_reaction,
|
||||||
web_client,
|
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"x",
|
"x",
|
||||||
@@ -186,8 +149,7 @@ class SlackChannel(Channel):
|
|||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
web_client = await self._get_web_client_for_message(msg)
|
if not self._web_client:
|
||||||
if not web_client:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -200,7 +162,7 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
kwargs["thread_ts"] = msg.thread_ts
|
kwargs["thread_ts"] = msg.thread_ts
|
||||||
|
|
||||||
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
||||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -209,45 +171,12 @@ class SlackChannel(Channel):
|
|||||||
|
|
||||||
# -- internal ----------------------------------------------------------
|
# -- internal ----------------------------------------------------------
|
||||||
|
|
||||||
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||||
self._web_client = self._web_client_factory(token=bot_token)
|
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||||
if self._bot_user_id is not None:
|
if not self._web_client:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
self._web_client.reactions_add(
|
||||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
|
||||||
if user_id is None:
|
|
||||||
auth_get = getattr(auth_info, "get", None)
|
|
||||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
|
||||||
if isinstance(user_id, str) and user_id:
|
|
||||||
self._bot_user_id = user_id
|
|
||||||
except Exception:
|
|
||||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
|
||||||
|
|
||||||
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
|
||||||
if msg.connection_id and self._connection_repo is not None:
|
|
||||||
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
|
||||||
access_token = credentials.get("access_token") if credentials else None
|
|
||||||
if not access_token:
|
|
||||||
return self._web_client
|
|
||||||
# WebClient keeps its own HTTP session and rate-limit state, so
|
|
||||||
# reuse one per connection until its token changes.
|
|
||||||
cached = self._connection_web_clients.get(msg.connection_id)
|
|
||||||
if cached is not None and cached[0] == access_token:
|
|
||||||
return cached[1]
|
|
||||||
if self._web_client_factory is None:
|
|
||||||
from slack_sdk import WebClient
|
|
||||||
|
|
||||||
self._web_client_factory = WebClient
|
|
||||||
web_client = self._web_client_factory(token=access_token)
|
|
||||||
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
|
|
||||||
return web_client
|
|
||||||
return self._web_client
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
|
||||||
try:
|
|
||||||
web_client.reactions_add(
|
|
||||||
channel=channel_id,
|
channel=channel_id,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
name=emoji,
|
name=emoji,
|
||||||
@@ -256,12 +185,6 @@ class SlackChannel(Channel):
|
|||||||
if "already_reacted" not in str(exc):
|
if "already_reacted" not in str(exc):
|
||||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||||
|
|
||||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
|
||||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
|
||||||
if not self._web_client:
|
|
||||||
return
|
|
||||||
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
|
||||||
|
|
||||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||||
if not self._web_client:
|
if not self._web_client:
|
||||||
@@ -287,26 +210,17 @@ class SlackChannel(Channel):
|
|||||||
if event_type != "events_api":
|
if event_type != "events_api":
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._bot_user_id is None:
|
|
||||||
authorization = next((item for item in req.payload.get("authorizations", []) if isinstance(item, dict)), None)
|
|
||||||
user_id = authorization.get("user_id") if authorization else None
|
|
||||||
if isinstance(user_id, str) and user_id:
|
|
||||||
self._bot_user_id = user_id
|
|
||||||
|
|
||||||
event = req.payload.get("event", {})
|
event = req.payload.get("event", {})
|
||||||
etype = event.get("type", "")
|
etype = event.get("type", "")
|
||||||
|
|
||||||
# Handle message events (DM or @mention)
|
# Handle message events (DM or @mention)
|
||||||
if etype in ("message", "app_mention"):
|
if etype in ("message", "app_mention"):
|
||||||
self._handle_message_event(
|
self._handle_message_event(event)
|
||||||
event,
|
|
||||||
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing Slack event")
|
logger.exception("Error processing Slack event")
|
||||||
|
|
||||||
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
|
def _handle_message_event(self, event: dict) -> None:
|
||||||
# Ignore bot messages
|
# Ignore bot messages
|
||||||
if event.get("bot_id") or event.get("subtype"):
|
if event.get("bot_id") or event.get("subtype"):
|
||||||
return
|
return
|
||||||
@@ -319,28 +233,13 @@ class SlackChannel(Channel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
text = event.get("text", "").strip()
|
text = event.get("text", "").strip()
|
||||||
if event.get("type") == "app_mention":
|
|
||||||
text = _strip_leading_slack_bot_mention(text, self._bot_user_id)
|
|
||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
connect_code = extract_connect_code(text)
|
|
||||||
if connect_code:
|
|
||||||
if self._loop and self._loop.is_running():
|
|
||||||
asyncio.run_coroutine_threadsafe(
|
|
||||||
self._bind_connection_from_connect_code(
|
|
||||||
event=event,
|
|
||||||
team_id=str(team_id or event.get("team") or ""),
|
|
||||||
code=connect_code,
|
|
||||||
),
|
|
||||||
self._loop,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
channel_id = event.get("channel", "")
|
channel_id = event.get("channel", "")
|
||||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||||
|
|
||||||
if is_known_channel_command(text):
|
if text.startswith("/"):
|
||||||
msg_type = InboundMessageType.COMMAND
|
msg_type = InboundMessageType.COMMAND
|
||||||
else:
|
else:
|
||||||
msg_type = InboundMessageType.CHAT
|
msg_type = InboundMessageType.CHAT
|
||||||
@@ -362,61 +261,4 @@ class SlackChannel(Channel):
|
|||||||
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
||||||
# Send "running" reply first (fire-and-forget from SDK thread)
|
# Send "running" reply first (fire-and-forget from SDK thread)
|
||||||
self._send_running_reply(channel_id, thread_ts)
|
self._send_running_reply(channel_id, thread_ts)
|
||||||
if self._connection_repo is None:
|
|
||||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||||
else:
|
|
||||||
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
|
|
||||||
|
|
||||||
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
|
|
||||||
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
|
|
||||||
await self.bus.publish_inbound(inbound)
|
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
|
||||||
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="slack",
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
|
||||||
if self._connection_repo is None or not code:
|
|
||||||
return False
|
|
||||||
|
|
||||||
channel_id = str(event.get("channel") or "")
|
|
||||||
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
|
|
||||||
if state is None:
|
|
||||||
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
|
|
||||||
return True
|
|
||||||
|
|
||||||
user_id = str(event.get("user") or "")
|
|
||||||
if not user_id or not team_id:
|
|
||||||
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
|
|
||||||
return True
|
|
||||||
|
|
||||||
await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=state["owner_user_id"],
|
|
||||||
provider="slack",
|
|
||||||
external_account_id=user_id,
|
|
||||||
workspace_id=team_id,
|
|
||||||
metadata={
|
|
||||||
"team_id": team_id,
|
|
||||||
"channel_id": channel_id,
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
|
|
||||||
if not self._web_client or not channel_id:
|
|
||||||
return
|
|
||||||
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
|
|
||||||
if thread_ts:
|
|
||||||
kwargs["thread_ts"] = thread_ts
|
|
||||||
try:
|
|
||||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
|
|
||||||
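Note on the Slack hunk above: the handler runs on the SDK's own thread, so it hands coroutines to the event loop with `asyncio.run_coroutine_threadsafe` instead of awaiting them directly. The following is a minimal, self-contained sketch of that hand-off. The names `publish`, `sdk_callback`, and the queue are illustrative assumptions and are not part of the DeerFlow code.

```python
import asyncio
import threading


async def publish(queue: asyncio.Queue, item: str) -> None:
    # Runs on the event loop thread, so touching the queue is safe here.
    await queue.put(item)


async def main() -> list:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def sdk_callback() -> None:
        # Runs on a foreign thread (a stand-in for the SDK thread).
        future = asyncio.run_coroutine_threadsafe(publish(queue, "hello"), loop)
        # Optional: block this worker thread until the coroutine has finished.
        future.result(timeout=5)

    worker = threading.Thread(target=sdk_callback)
    worker.start()
    received = [await queue.get()]
    # Join off-loop so the event loop itself never blocks on the thread.
    await asyncio.to_thread(worker.join)
    return received


result = asyncio.run(main())
```

The same idea appears in `_post_connection_reply`, where the blocking `chat_postMessage` call is moved off the loop with `asyncio.to_thread`.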
|
|||||||
@@ -5,27 +5,13 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TELEGRAM_MAX_MESSAGE_LENGTH = 4096
|
|
||||||
STREAM_EDIT_MIN_INTERVAL_SECONDS = 1.0
|
|
||||||
# Groups (negative chat_id) are capped at 20 messages/minute by Telegram,
|
|
||||||
# so stream edits there must pace well below the private-chat 1 msg/s guideline.
|
|
||||||
STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS = 3.0
|
|
||||||
# Bound on tracked in-flight streamed messages; entries normally clear on the
|
|
||||||
# final update, this only guards against leaks when a final never arrives.
|
|
||||||
MAX_TRACKED_STREAM_MESSAGES = 256
|
|
||||||
|
|
||||||
# Indirection so tests can patch the clock without touching the global time module.
|
|
||||||
_monotonic = time.monotonic
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramChannel(Channel):
|
class TelegramChannel(Channel):
|
||||||
"""Telegram bot channel using long-polling.
|
"""Telegram bot channel using long-polling.
|
||||||
@@ -49,14 +35,6 @@ class TelegramChannel(Channel):
|
|||||||
pass
|
pass
|
||||||
# chat_id -> last sent message_id for threaded replies
|
# chat_id -> last sent message_id for threaded replies
|
||||||
self._last_bot_message: dict[str, int] = {}
|
self._last_bot_message: dict[str, int] = {}
|
||||||
# stream_key ("chat_id:thread_ts") -> state of the in-flight streamed
|
|
||||||
# bot message being edited in place: {"message_id", "last_edit_at", "last_text"}
|
|
||||||
self._stream_messages: dict[str, dict[str, Any]] = {}
|
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -82,17 +60,12 @@ class TelegramChannel(Channel):
|
|||||||
|
|
||||||
# Command handlers
|
# Command handlers
|
||||||
app.add_handler(CommandHandler("start", self._cmd_start))
|
app.add_handler(CommandHandler("start", self._cmd_start))
|
||||||
app.add_handler(CommandHandler("bootstrap", self._cmd_generic))
|
|
||||||
app.add_handler(CommandHandler("new", self._cmd_generic))
|
app.add_handler(CommandHandler("new", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("status", self._cmd_generic))
|
app.add_handler(CommandHandler("status", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("models", self._cmd_generic))
|
app.add_handler(CommandHandler("models", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("memory", self._cmd_generic))
|
app.add_handler(CommandHandler("memory", self._cmd_generic))
|
||||||
app.add_handler(CommandHandler("help", self._cmd_generic))
|
app.add_handler(CommandHandler("help", self._cmd_generic))
|
||||||
|
|
||||||
# Slash skill commands are dynamic and cannot all be pre-registered
|
|
||||||
# with Telegram, so route unknown slash commands through chat handling.
|
|
||||||
app.add_handler(MessageHandler(filters.TEXT & filters.COMMAND, self._on_text))
|
|
||||||
|
|
||||||
# General message handler
|
# General message handler
|
||||||
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
|
app.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND, self._on_text))
|
||||||
|
|
||||||
@@ -124,117 +97,10 @@ class TelegramChannel(Channel):
|
|||||||
logger.error("Invalid Telegram chat_id: %s", msg.chat_id)
|
logger.error("Invalid Telegram chat_id: %s", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
|
||||||
key = self._stream_key(msg.chat_id, msg.thread_ts)
|
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": msg.text}
|
||||||
|
|
||||||
if not msg.is_final:
|
|
||||||
await self._send_stream_update(chat_id, key, msg.text, reply_to=self._parse_message_id(msg.thread_ts))
|
|
||||||
return
|
|
||||||
|
|
||||||
state = self._stream_messages.pop(key, None)
|
|
||||||
if state is not None:
|
|
||||||
await self._finalize_stream_message(chat_id, msg.chat_id, state, msg.text)
|
|
||||||
return
|
|
||||||
|
|
||||||
await self._send_new_message(chat_id, msg.chat_id, msg.text, _max_retries=_max_retries)
|
|
||||||
|
|
||||||
async def _send_stream_update(self, chat_id: int, key: str, text: str, reply_to: int | None = None) -> None:
|
|
||||||
"""Edit the in-flight streamed message with accumulated text.
|
|
||||||
|
|
||||||
Updates are best-effort: throttled, rate-limit drops are silent. The
|
|
||||||
manager always publishes a final message afterwards, which guarantees
|
|
||||||
delivery of the complete text.
|
|
||||||
"""
|
|
||||||
if not text:
|
|
||||||
return
|
|
||||||
|
|
||||||
display = text
|
|
||||||
if len(display) > TELEGRAM_MAX_MESSAGE_LENGTH:
|
|
||||||
display = display[: TELEGRAM_MAX_MESSAGE_LENGTH - 1] + "…"
|
|
||||||
|
|
||||||
bot = self._application.bot
|
|
||||||
state = self._stream_messages.get(key)
|
|
||||||
|
|
||||||
send_kwargs: dict[str, Any] = {"chat_id": chat_id, "text": display}
|
|
||||||
if reply_to:
|
|
||||||
send_kwargs["reply_to_message_id"] = reply_to
|
|
||||||
|
|
||||||
if state is None:
|
|
||||||
try:
|
|
||||||
sent = await bot.send_message(**send_kwargs)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Telegram] failed to start stream message in chat=%s", chat_id)
|
|
||||||
return
|
|
||||||
self._register_stream_message(key, message_id=sent.message_id, last_text=display, last_edit_at=_monotonic())
|
|
||||||
return
|
|
||||||
|
|
||||||
now = _monotonic()
|
|
||||||
min_interval = STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS if chat_id < 0 else STREAM_EDIT_MIN_INTERVAL_SECONDS
|
|
||||||
if now - state["last_edit_at"] < min_interval:
|
|
||||||
return
|
|
||||||
if display == state["last_text"]:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await bot.edit_message_text(chat_id=chat_id, message_id=state["message_id"], text=display)
|
|
||||||
except Exception as exc:
|
|
||||||
if self._is_not_modified(exc):
|
|
||||||
state["last_text"] = display
|
|
||||||
return
|
|
||||||
if self._is_retry_after(exc):
|
|
||||||
logger.debug("[Telegram] stream edit rate-limited in chat=%s, dropping update", chat_id)
|
|
||||||
return
|
|
||||||
logger.warning("[Telegram] stream edit failed in chat=%s, sending new message: %s", chat_id, exc)
|
|
||||||
try:
|
|
||||||
sent = await bot.send_message(**send_kwargs)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Telegram] failed to send fallback stream message in chat=%s", chat_id)
|
|
||||||
return
|
|
||||||
state["message_id"] = sent.message_id
|
|
||||||
|
|
||||||
state["last_edit_at"] = _monotonic()
|
|
||||||
state["last_text"] = display
|
|
||||||
|
|
||||||
async def _finalize_stream_message(self, chat_id: int, chat_key: str, state: dict[str, Any], text: str) -> None:
|
|
||||||
"""Apply the final text: edit the streamed message, splitting overflow into follow-ups."""
|
|
||||||
bot = self._application.bot
|
|
||||||
chunks = self._split_message(text or "")
|
|
||||||
|
|
||||||
edited = True
|
|
||||||
if chunks[0] != state["last_text"]:
|
|
||||||
edited = await self._edit_final_chunk(bot, chat_id, state["message_id"], chunks[0])
|
|
||||||
|
|
||||||
if edited:
|
|
||||||
self._last_bot_message[chat_key] = state["message_id"]
|
|
||||||
else:
|
|
||||||
# Edit could not be applied (e.g. message deleted) — deliver the
|
|
||||||
# first chunk as a fresh message with the standard retry policy.
|
|
||||||
await self._send_new_message(chat_id, chat_key, chunks[0])
|
|
||||||
|
|
||||||
for chunk in chunks[1:]:
|
|
||||||
await self._send_new_message(chat_id, chat_key, chunk)
|
|
||||||
|
|
||||||
async def _edit_final_chunk(self, bot, chat_id: int, message_id: int, text: str) -> bool:
|
|
||||||
"""Edit with one rate-limit retry. Returns False if the edit could not be applied."""
|
|
||||||
for attempt in range(2):
|
|
||||||
try:
|
|
||||||
await bot.edit_message_text(chat_id=chat_id, message_id=message_id, text=text)
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
if self._is_not_modified(exc):
|
|
||||||
return True
|
|
||||||
if self._is_retry_after(exc) and attempt == 0:
|
|
||||||
await asyncio.sleep(self._retry_after_seconds(exc))
|
|
||||||
continue
|
|
||||||
logger.warning("[Telegram] final edit failed in chat=%s: %s", chat_id, exc)
|
|
||||||
return False
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _send_new_message(self, chat_id: int, chat_key: str, text: str, *, _max_retries: int = 3) -> int | None:
|
|
||||||
"""Send a fresh message with retry/backoff. Returns the sent message_id."""
|
|
||||||
kwargs: dict[str, Any] = {"chat_id": chat_id, "text": text}
|
|
||||||
|
|
||||||
# Reply to the last bot message in this chat for threading
|
# Reply to the last bot message in this chat for threading
|
||||||
reply_to = self._last_bot_message.get(chat_key)
|
reply_to = self._last_bot_message.get(msg.chat_id)
|
||||||
if reply_to:
|
if reply_to:
|
||||||
kwargs["reply_to_message_id"] = reply_to
|
kwargs["reply_to_message_id"] = reply_to
|
||||||
|
|
||||||
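Note on the hunk above: the left-hand `_send_stream_update` decides, for every intermediate update, whether to send a first message, edit the existing one, or silently drop the update. A minimal sketch of that decision as a pure function follows. The constants are copied from the diff; the function name `plan_stream_update` and its string return values are assumptions made for illustration only.

```python
from __future__ import annotations

TELEGRAM_MAX_MESSAGE_LENGTH = 4096
STREAM_EDIT_MIN_INTERVAL_SECONDS = 1.0
STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS = 3.0


def plan_stream_update(state: dict | None, chat_id: int, text: str, now: float) -> str:
    """Return 'skip', 'send' or 'edit' for an intermediate stream update."""
    if not text:
        return "skip"

    # Intermediate updates are truncated to fit into a single message.
    display = text
    if len(display) > TELEGRAM_MAX_MESSAGE_LENGTH:
        display = display[: TELEGRAM_MAX_MESSAGE_LENGTH - 1] + "…"

    if state is None:
        return "send"  # nothing in flight yet: start a new message

    # Negative chat ids are groups, which get the slower pacing.
    if chat_id < 0:
        min_interval = STREAM_EDIT_GROUP_MIN_INTERVAL_SECONDS
    else:
        min_interval = STREAM_EDIT_MIN_INTERVAL_SECONDS

    if now - state["last_edit_at"] < min_interval:
        return "skip"  # throttled
    if display == state["last_text"]:
        return "skip"  # nothing new to show
    return "edit"


state = {"message_id": 1, "last_edit_at": 10.0, "last_text": "Working on it..."}
decisions = [
    plan_stream_update(None, 42, "partial", 10.0),
    plan_stream_update(state, 42, "partial", 10.5),
    plan_stream_update(state, 42, "partial", 11.5),
    plan_stream_update(state, -100, "partial", 11.5),
    plan_stream_update(state, -100, "partial", 13.5),
    plan_stream_update(state, 42, "Working on it...", 20.0),
]
```

Passing `now` in as a parameter plays the same role as the `_monotonic` indirection in the diff: tests can control the clock without patching the `time` module.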
@@ -243,8 +109,8 @@ class TelegramChannel(Channel):
|
|||||||
for attempt in range(_max_retries):
|
for attempt in range(_max_retries):
|
||||||
try:
|
try:
|
||||||
sent = await bot.send_message(**kwargs)
|
sent = await bot.send_message(**kwargs)
|
||||||
self._last_bot_message[chat_key] = sent.message_id
|
self._last_bot_message[msg.chat_id] = sent.message_id
|
||||||
return sent.message_id
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_exc = exc
|
last_exc = exc
|
||||||
if attempt < _max_retries - 1:
|
if attempt < _max_retries - 1:
|
||||||
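Note on `_edit_final_chunk` above: the final edit is retried exactly once when the error carries a `retry_after` value, which may be a number or a `timedelta`-like object. A minimal sketch of that pattern follows, under stated assumptions: `RateLimited`, `edit_with_one_retry`, and the injectable `sleep` parameter are hypothetical names for this example and not part of the Telegram library or of DeerFlow.

```python
import asyncio
from datetime import timedelta


class RateLimited(Exception):
    """Hypothetical error carrying a retry_after hint."""

    def __init__(self, retry_after):
        super().__init__("rate limited")
        self.retry_after = retry_after


def retry_after_seconds(exc: Exception) -> float:
    # Accept both plain numbers and timedelta-like values.
    value = getattr(exc, "retry_after", 0)
    if hasattr(value, "total_seconds"):
        return float(value.total_seconds())
    return float(value)


async def edit_with_one_retry(edit, sleep=asyncio.sleep) -> bool:
    """Call edit(); on a rate limit, wait once and try again."""
    for attempt in range(2):
        try:
            await edit()
            return True
        except Exception as exc:
            rate_limited = getattr(exc, "retry_after", None) is not None
            if rate_limited and attempt == 0:
                await sleep(retry_after_seconds(exc))
                continue
            return False
    return False


async def demo():
    calls = 0
    sleeps = []

    async def flaky_edit() -> None:
        nonlocal calls
        calls += 1
        if calls == 1:
            raise RateLimited(timedelta(seconds=2))

    async def fake_sleep(seconds: float) -> None:
        sleeps.append(seconds)  # record instead of actually waiting

    ok = await edit_with_one_retry(flaky_edit, sleep=fake_sleep)
    return ok, calls, sleeps


outcome = asyncio.run(demo())
```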
@@ -307,63 +173,17 @@ class TelegramChannel(Channel):
|
|||||||
|
|
||||||
# -- helpers -----------------------------------------------------------
|
# -- helpers -----------------------------------------------------------
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stream_key(chat_id: str, thread_ts: str | None) -> str:
|
|
||||||
return f"{chat_id}:{thread_ts or ''}"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_message_id(value: str | None) -> int | None:
|
|
||||||
try:
|
|
||||||
return int(value) if value else None
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _register_stream_message(self, key: str, *, message_id: int, last_text: str, last_edit_at: float) -> None:
|
|
||||||
self._stream_messages.pop(key, None)
|
|
||||||
while len(self._stream_messages) >= MAX_TRACKED_STREAM_MESSAGES:
|
|
||||||
self._stream_messages.pop(next(iter(self._stream_messages)))
|
|
||||||
self._stream_messages[key] = {
|
|
||||||
"message_id": message_id,
|
|
||||||
"last_edit_at": last_edit_at,
|
|
||||||
"last_text": last_text,
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_retry_after(exc: Exception) -> bool:
|
|
||||||
return getattr(exc, "retry_after", None) is not None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _retry_after_seconds(exc: Exception) -> float:
|
|
||||||
value = getattr(exc, "retry_after", 0)
|
|
||||||
if hasattr(value, "total_seconds"):
|
|
||||||
return float(value.total_seconds())
|
|
||||||
return float(value)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_not_modified(exc: Exception) -> bool:
|
|
||||||
return "message is not modified" in str(exc).lower()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _split_message(text: str) -> list[str]:
|
|
||||||
return [text[i : i + TELEGRAM_MAX_MESSAGE_LENGTH] for i in range(0, len(text), TELEGRAM_MAX_MESSAGE_LENGTH)] or [text]
|
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||||
"""Send a 'Working on it...' reply and register it as the stream target."""
|
"""Send a 'Working on it...' reply to the user's message."""
|
||||||
if not self._application:
|
if not self._application:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
bot = self._application.bot
|
bot = self._application.bot
|
||||||
sent = await bot.send_message(
|
await bot.send_message(
|
||||||
chat_id=int(chat_id),
|
chat_id=int(chat_id),
|
||||||
text="Working on it...",
|
text="Working on it...",
|
||||||
reply_to_message_id=reply_to_message_id,
|
reply_to_message_id=reply_to_message_id,
|
||||||
)
|
)
|
||||||
self._register_stream_message(
|
|
||||||
self._stream_key(chat_id, str(reply_to_message_id)),
|
|
||||||
message_id=sent.message_id,
|
|
||||||
last_text="Working on it...",
|
|
||||||
last_edit_at=0.0,
|
|
||||||
)
|
|
||||||
logger.info("[Telegram] 'Working on it...' reply sent in chat=%s", chat_id)
|
logger.info("[Telegram] 'Working on it...' reply sent in chat=%s", chat_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
||||||
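Note on the helpers hunk above: `_register_stream_message` bounds the tracking dict by evicting the oldest entry, and `_split_message` chunks long text at a fixed length. The sketch below reproduces both ideas in isolation. The small limits and the function names `register` and `split_message` are assumptions chosen to keep the example readable; the diff itself uses 256 entries and 4096 characters.

```python
MAX_TRACKED = 3  # the diff uses MAX_TRACKED_STREAM_MESSAGES = 256


def register(tracked: dict, key: str, message_id: int) -> None:
    # Re-registering a key moves it to the newest position.
    tracked.pop(key, None)
    # dicts preserve insertion order, so the first key is the oldest one.
    while len(tracked) >= MAX_TRACKED:
        tracked.pop(next(iter(tracked)))
    tracked[key] = {"message_id": message_id}


def split_message(text: str, limit: int) -> list:
    # "or [text]" keeps a single (empty) chunk when text is empty,
    # so callers can always index chunks[0].
    return [text[i : i + limit] for i in range(0, len(text), limit)] or [text]


tracked: dict = {}
for index, key in enumerate(["a", "b", "c", "d"]):
    register(tracked, key, index)
keys_after_overflow = list(tracked)

register(tracked, "b", 99)
keys_after_refresh = list(tracked)

chunks = split_message("abcdefg", 3)
empty_chunks = split_message("", 3)
```

Eviction here is first-in, first-out by insertion order, not least-recently-used: reading an entry does not refresh its position, only re-registering does.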
@@ -408,90 +228,10 @@ class TelegramChannel(Channel):
|
|||||||
return True
|
return True
|
||||||
return user_id in self._allowed_users
|
return user_id in self._allowed_users
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _telegram_display_name(user) -> str:
|
|
||||||
full_name = getattr(user, "full_name", None)
|
|
||||||
if isinstance(full_name, str) and full_name:
|
|
||||||
return full_name
|
|
||||||
username = getattr(user, "username", None)
|
|
||||||
if isinstance(username, str) and username:
|
|
||||||
return username
|
|
||||||
return str(getattr(user, "id", ""))
|
|
||||||
|
|
||||||
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
|
||||||
if self._connection_repo is None or not state_token:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
|
||||||
if state is None:
|
|
||||||
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
owner_user_id = state["owner_user_id"]
|
|
||||||
user_id = str(update.effective_user.id)
|
|
||||||
chat_id = str(update.effective_chat.id)
|
|
||||||
connection = await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=owner_user_id,
|
|
||||||
provider="telegram",
|
|
||||||
external_account_id=user_id,
|
|
||||||
external_account_name=self._telegram_display_name(update.effective_user),
|
|
||||||
workspace_id=chat_id,
|
|
||||||
workspace_name=None,
|
|
||||||
metadata={
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"chat_type": update.effective_chat.type,
|
|
||||||
"telegram_username": getattr(update.effective_user, "username", None),
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
|
||||||
await update.message.reply_text("Telegram connected to DeerFlow.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="telegram",
|
|
||||||
workspace_id=inbound.chat_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_bot_username(self, context) -> str | None:
|
|
||||||
bot = getattr(context, "bot", None)
|
|
||||||
username = getattr(bot, "username", None)
|
|
||||||
if not username and self._application is not None:
|
|
||||||
username = getattr(getattr(self._application, "bot", None), "username", None)
|
|
||||||
return str(username) if username else None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _strip_bot_username_from_leading_command(text: str, bot_username: str | None) -> str:
|
|
||||||
username = (bot_username or "").lstrip("@").lower()
|
|
||||||
if not username or not text.startswith("/"):
|
|
||||||
return text
|
|
||||||
|
|
||||||
parts = text.split(maxsplit=1)
|
|
||||||
command_token = parts[0]
|
|
||||||
if "@" not in command_token:
|
|
||||||
return text
|
|
||||||
|
|
||||||
command_name, addressed_username = command_token[1:].rsplit("@", 1)
|
|
||||||
if not command_name or addressed_username.lower() != username:
|
|
||||||
return text
|
|
||||||
|
|
||||||
normalized = f"/{command_name}"
|
|
||||||
if len(parts) > 1:
|
|
||||||
normalized = f"{normalized} {parts[1]}"
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
async def _cmd_start(self, update, context) -> None:
|
async def _cmd_start(self, update, context) -> None:
|
||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
args = getattr(context, "args", []) if context is not None else []
|
|
||||||
if args:
|
|
||||||
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
|
||||||
if handled:
|
|
||||||
return
|
|
||||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||||
|
|
||||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||||
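Note on `_strip_bot_username_from_leading_command` above: in group chats Telegram clients may address a command to a specific bot as `/command@BotName`. The helper normalizes this only when the addressed name matches the bot's own username. A standalone copy of the logic is shown below for experimentation; the sample bot name `DeerBot` is invented for the example.

```python
from __future__ import annotations


def strip_bot_username(text: str, bot_username: str | None) -> str:
    username = (bot_username or "").lstrip("@").lower()
    if not username or not text.startswith("/"):
        return text

    parts = text.split(maxsplit=1)
    command_token = parts[0]
    if "@" not in command_token:
        return text

    # "/status@DeerBot" -> ("status", "DeerBot")
    command_name, addressed_username = command_token[1:].rsplit("@", 1)
    if not command_name or addressed_username.lower() != username:
        return text  # addressed to a different bot: leave untouched

    normalized = f"/{command_name}"
    if len(parts) > 1:
        normalized = f"{normalized} {parts[1]}"
    return normalized


samples = [
    strip_bot_username("/status@DeerBot now", "DeerBot"),
    strip_bot_username("/help@deerbot", "@DeerBot"),
    strip_bot_username("/status@OtherBot", "DeerBot"),
    strip_bot_username("hello @DeerBot", "DeerBot"),
    strip_bot_username("/status@DeerBot", None),
]
```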
@@ -503,7 +243,7 @@ class TelegramChannel(Channel):
|
|||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
|
|
||||||
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
|
text = update.message.text
|
||||||
chat_id = str(update.effective_chat.id)
|
chat_id = str(update.effective_chat.id)
|
||||||
user_id = str(update.effective_user.id)
|
user_id = str(update.effective_user.id)
|
||||||
msg_id = str(update.message.message_id)
|
msg_id = str(update.message.message_id)
|
||||||
@@ -527,7 +267,6 @@ class TelegramChannel(Channel):
|
|||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
inbound = await self._attach_connection_identity(inbound)
|
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||||
@@ -540,7 +279,7 @@ class TelegramChannel(Channel):
|
|||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
|
|
||||||
text = self._strip_bot_username_from_leading_command(update.message.text.strip(), self._get_bot_username(context))
|
text = update.message.text.strip()
|
||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -570,7 +309,6 @@ class TelegramChannel(Channel):
|
|||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
inbound = await self._attach_connection_identity(inbound)
|
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||||
|
|||||||
@@ -22,9 +22,7 @@ from cryptography.hazmat.primitives import padding
|
|||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -254,7 +252,6 @@ class WechatChannel(Channel):
|
|||||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
self._load_state()
|
self._load_state()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -619,21 +616,11 @@ class WechatChannel(Channel):
|
|||||||
if thread_ts:
|
if thread_ts:
|
||||||
self._context_tokens_by_thread[thread_ts] = context_token
|
self._context_tokens_by_thread[thread_ts] = context_token
|
||||||
|
|
||||||
connect_code = extract_connect_code(text)
|
|
||||||
if connect_code and self._connection_repo is not None:
|
|
||||||
handled = await self._bind_connection_from_connect_code(
|
|
||||||
chat_id=chat_id,
|
|
||||||
context_token=context_token,
|
|
||||||
code=connect_code,
|
|
||||||
)
|
|
||||||
if handled:
|
|
||||||
return
|
|
||||||
|
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=chat_id,
|
user_id=chat_id,
|
||||||
text=text,
|
text=text,
|
||||||
msg_type=InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT,
|
msg_type=InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT,
|
||||||
thread_ts=thread_ts,
|
thread_ts=thread_ts,
|
||||||
files=files,
|
files=files,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -644,54 +631,8 @@ class WechatChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = None
|
inbound.topic_id = None
|
||||||
inbound = await self._attach_connection_identity(inbound)
|
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="wechat",
|
|
||||||
workspace_id=inbound.chat_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
|
|
||||||
if self._connection_repo is None or not code:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
|
|
||||||
if state is None:
|
|
||||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not chat_id:
|
|
||||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=state["owner_user_id"],
|
|
||||||
provider="wechat",
|
|
||||||
external_account_id=chat_id,
|
|
||||||
workspace_id=chat_id,
|
|
||||||
metadata={
|
|
||||||
"context_token": context_token,
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
|
|
||||||
if not context_token:
|
|
||||||
return
|
|
||||||
await self._send_text_message(
|
|
||||||
chat_id=chat_id,
|
|
||||||
context_token=context_token,
|
|
||||||
text=text,
|
|
||||||
client_id_prefix="deerflow-connect",
|
|
||||||
max_retries=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _ensure_authenticated(self) -> bool:
|
async def _ensure_authenticated(self) -> bool:
|
||||||
async with self._auth_lock:
|
async with self._auth_lock:
|
||||||
if self._bot_token:
|
if self._bot_token:
|
||||||
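Note on the WeChat hunks above: `_bind_connection_from_connect_code` follows the same three-step shape as the Slack and Telegram variants (consume a pending state, reject if missing, otherwise upsert the connection). The sketch below walks through that flow against an in-memory repository. `InMemoryConnectionRepo` is a hypothetical stand-in; the real repository is not shown in this diff, and the single-use behavior is an assumption inferred from the method name `consume_oauth_state`.

```python
import asyncio


class InMemoryConnectionRepo:
    """Hypothetical stand-in for the connection repository."""

    def __init__(self) -> None:
        self.states = {}
        self.connections = []

    async def consume_oauth_state(self, *, provider: str, state: str):
        # pop() makes each code single-use (assumed semantics).
        return self.states.pop((provider, state), None)

    async def upsert_connection(self, **fields):
        self.connections.append(fields)
        return fields


async def bind(repo, *, provider: str, chat_id: str, code: str):
    """Return (handled, reply_text), mirroring the flow in the diff."""
    if repo is None or not code:
        return False, ""
    state = await repo.consume_oauth_state(provider=provider, state=code)
    if state is None:
        return True, "invalid or expired"
    if not chat_id:
        return True, "could not be completed"
    await repo.upsert_connection(
        owner_user_id=state["owner_user_id"],
        provider=provider,
        external_account_id=chat_id,
        workspace_id=chat_id,
        status="connected",
    )
    return True, "connected"


async def demo():
    repo = InMemoryConnectionRepo()
    repo.states[("wechat", "abc123")] = {"owner_user_id": "user-1"}
    first = await bind(repo, provider="wechat", chat_id="chat-9", code="abc123")
    second = await bind(repo, provider="wechat", chat_id="chat-9", code="abc123")
    disabled = await bind(None, provider="wechat", chat_id="chat-9", code="abc123")
    return first, second, disabled, len(repo.connections)


first, second, disabled, stored = asyncio.run(demo())
```

Returning `True` for an invalid code matches the diff: the message is treated as handled, so it is not forwarded to the chat pipeline.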
|
|||||||
@@ -8,10 +8,7 @@ from collections.abc import Awaitable, Callable
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
|
||||||
from app.channels.connection_identity import attach_connection_identity
|
|
||||||
from app.channels.message_bus import (
|
from app.channels.message_bus import (
|
||||||
InboundMessage,
|
|
||||||
InboundMessageType,
|
InboundMessageType,
|
||||||
MessageBus,
|
MessageBus,
|
||||||
OutboundMessage,
|
OutboundMessage,
|
||||||
@@ -31,7 +28,6 @@ class WeComChannel(Channel):
|
|||||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||||
self._ws_stream_ids: dict[str, str] = {}
|
self._ws_stream_ids: dict[str, str] = {}
|
||||||
self._working_message = "Working on it..."
|
self._working_message = "Working on it..."
|
||||||
self._connection_repo = config.get("connection_repo")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
@@ -274,17 +270,7 @@ class WeComChannel(Channel):
|
|||||||
|
|
||||||
user_id = (body.get("from") or {}).get("userid")
|
user_id = (body.get("from") or {}).get("userid")
|
||||||
|
|
||||||
connect_code = extract_connect_code(text)
|
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||||
if connect_code and self._connection_repo is not None:
|
|
||||||
handled = await self._bind_connection_from_connect_code(
|
|
||||||
frame=frame,
|
|
||||||
user_id=str(user_id or ""),
|
|
||||||
code=connect_code,
|
|
||||||
)
|
|
||||||
if handled:
|
|
||||||
return
|
|
||||||
|
|
||||||
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=user_id, # keep user's conversation in memory
|
chat_id=user_id, # keep user's conversation in memory
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -305,52 +291,8 @@ class WeComChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
inbound = await self._attach_connection_identity(inbound)
|
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
|
||||||
return await attach_connection_identity(
|
|
||||||
inbound,
|
|
||||||
repo=self._connection_repo,
|
|
||||||
provider="wecom",
|
|
||||||
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
|
||||||
fallback_without_workspace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
|
||||||
if self._connection_repo is None or not code:
|
|
||||||
return False
|
|
||||||
|
|
||||||
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
|
||||||
if state is None:
|
|
||||||
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
body = frame.get("body", {}) or {}
|
|
||||||
workspace_id = str(body.get("aibotid") or "") or None
|
|
||||||
await self._connection_repo.upsert_connection(
|
|
||||||
owner_user_id=state["owner_user_id"],
|
|
||||||
provider="wecom",
|
|
||||||
external_account_id=user_id,
|
|
||||||
workspace_id=workspace_id,
|
|
||||||
metadata={
|
|
||||||
"aibotid": workspace_id,
|
|
||||||
"chattype": body.get("chattype"),
|
|
||||||
},
|
|
||||||
status="connected",
|
|
||||||
)
|
|
||||||
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
|
||||||
if not self._ws_client:
|
|
||||||
return
|
|
||||||
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
|
||||||
|
|
||||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||||
if not self._ws_client:
|
if not self._ws_client:
|
||||||
return
|
return
|
||||||
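Note on the WeChat and WeCom hunks above: one side classifies any text starting with `/` as a command, the other calls `is_known_channel_command(text)`. The body of that helper is not part of this diff, so the sketch below is only a guess at its shape. The command set is taken from the Telegram `CommandHandler` registrations shown earlier, and the local `InboundMessageType` enum is a stand-in for the real one.

```python
from enum import Enum


class InboundMessageType(Enum):  # local stand-in for the real enum
    CHAT = "chat"
    COMMAND = "command"


# Assumed set, based on the Telegram handlers registered in the diff.
KNOWN_COMMANDS = frozenset({"/bootstrap", "/new", "/status", "/models", "/memory", "/help"})


def classify_by_prefix(text: str) -> InboundMessageType:
    # Every slash-prefixed message is a command.
    return InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT


def classify_by_known_set(text: str) -> InboundMessageType:
    # Only recognized commands are commands; other slash messages
    # (for example dynamic skill commands) fall through to chat.
    parts = text.split(maxsplit=1)
    token = parts[0].lower() if parts else ""
    return InboundMessageType.COMMAND if token in KNOWN_COMMANDS else InboundMessageType.CHAT


skill_text = "/my-skill summarize this"
by_prefix = classify_by_prefix(skill_text)
by_known = classify_by_known_set(skill_text)
```

The difference only shows up for slash-prefixed text that is not a built-in command, which is the case the Telegram comment about dynamic slash skill commands describes.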
|
|||||||
@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
|
||||||
from app.gateway.auth_middleware import AuthMiddleware
|
from app.gateway.auth_middleware import AuthMiddleware
|
||||||
from app.gateway.config import get_gateway_config
|
from app.gateway.config import get_gateway_config
|
||||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
||||||
@@ -16,7 +15,6 @@ from app.gateway.routers import (
|
|||||||
artifacts,
|
artifacts,
|
||||||
assistants_compat,
|
assistants_compat,
|
||||||
auth,
|
auth,
|
||||||
channel_connections,
|
|
||||||
channels,
|
channels,
|
||||||
feedback,
|
feedback,
|
||||||
mcp,
|
mcp,
|
||||||
@@ -174,7 +172,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
startup_config = get_app_config()
|
startup_config = get_app_config()
|
||||||
apply_logging_level(startup_config.log_level)
|
apply_logging_level(startup_config.log_level)
|
||||||
logger.info("Configuration loaded successfully")
|
logger.info("Configuration loaded successfully")
|
||||||
warn_if_auth_disabled_enabled()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||||
logger.exception(error_msg)
|
logger.exception(error_msg)
|
||||||
@@ -182,31 +179,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
config = get_gateway_config()
|
config = get_gateway_config()
|
||||||
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
logger.info(f"Starting API Gateway on {config.host}:{config.port}")
|
||||||
|
|
||||||
-    # Pre-warm tiktoken encoding cache so the first memory-injection request
-    # never blocks on the BPE data download (which hits an OpenAI/Azure URL
-    # that may be unreachable in restricted networks — see issue #3402).
-    # When memory.token_counting is "char", token counting never touches
-    # tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
-    # network-restricted deployments — see issue #3429).
-    if startup_config.memory.token_counting == "char":
-        logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
-    else:
-        try:
-            from deerflow.agents.memory.prompt import warm_tiktoken_cache
-
-            warmed = await asyncio.wait_for(
-                asyncio.to_thread(warm_tiktoken_cache),
-                timeout=5,
-            )
-            if warmed:
-                logger.info("tiktoken encoding cache warmed successfully")
-            else:
-                logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
-        except TimeoutError:
-            logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
-        except Exception:
-            logger.warning("tiktoken warm-up skipped", exc_info=True)
-
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||||
async with langgraph_runtime(app, startup_config):
|
async with langgraph_runtime(app, startup_config):
|
||||||
logger.info("LangGraph runtime initialised")
|
logger.info("LangGraph runtime initialised")
|
||||||
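The warm-up removed in the hunk above combines `asyncio.to_thread` with `asyncio.wait_for` so that a blocking call neither stalls the event loop nor delays startup indefinitely. A minimal sketch of that pattern follows; `slow_warm_up` and `warm_up_with_deadline` are hypothetical stand-ins, not DeerFlow functions.

```python
import asyncio
import time


def slow_warm_up(delay: float) -> bool:
    """Hypothetical stand-in for a blocking call such as a cache download."""
    time.sleep(delay)
    return True


async def warm_up_with_deadline(delay: float, timeout: float) -> str:
    """Run the blocking call in a worker thread and stop waiting after `timeout` seconds."""
    try:
        warmed = await asyncio.wait_for(asyncio.to_thread(slow_warm_up, delay), timeout=timeout)
    except asyncio.TimeoutError:
        # Only the await is abandoned; the worker thread itself keeps running
        # until the blocking call returns.
        return "timed out"
    return "warmed" if warmed else "failed"
```

Catching `asyncio.TimeoutError` keeps the sketch portable: on Python 3.11 and later it is the same class as the built-in `TimeoutError` used in the diff.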
@@ -385,9 +357,6 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
|||||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||||
app.include_router(suggestions.router)
|
app.include_router(suggestions.router)
|
||||||
|
|
||||||
# User-facing IM channel connection API is mounted at /api/channels
|
|
||||||
app.include_router(channel_connections.router)
|
|
||||||
|
|
||||||
# Channels API is mounted at /api/channels
|
# Channels API is mounted at /api/channels
|
||||||
app.include_router(channels.router)
|
app.include_router(channels.router)
|
||||||
|
|
||||||
|
|||||||
@@ -1,56 +0,0 @@
-"""Shared helpers for local/E2E auth-disabled mode."""
-
-from __future__ import annotations
-
-import logging
-import os
-from types import SimpleNamespace
-
-from deerflow.runtime.user_context import DEFAULT_USER_ID
-
-AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
-AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
-AUTH_DISABLED_USER_EMAIL = "default@test.local"
-
-AUTH_SOURCE_SESSION = "session"
-AUTH_SOURCE_INTERNAL = "internal"
-AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
-
-_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
-_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
-
-logger = logging.getLogger(__name__)
-
-
-def is_explicit_production_environment() -> bool:
-    return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
-
-
-def is_auth_disabled_requested() -> bool:
-    return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
-
-
-def is_auth_disabled() -> bool:
-    return is_auth_disabled_requested() and not is_explicit_production_environment()
-
-
-def warn_if_auth_disabled_enabled() -> None:
-    if not is_auth_disabled():
-        return
-
-    logger.warning(
-        "%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
-        AUTH_DISABLED_ENV_VAR,
-        AUTH_DISABLED_USER_ID,
-    )
-
-
-def get_auth_disabled_user():
-    return SimpleNamespace(
-        id=AUTH_DISABLED_USER_ID,
-        email=AUTH_DISABLED_USER_EMAIL,
-        password_hash=None,
-        system_role="admin",
-        needs_setup=False,
-        token_version=0,
-    )
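The removed module gates the auth bypass on two conditions: the flag must be exactly `"1"`, and no production marker may be set. The sketch below restates that rule with the environment passed in as a mapping. The `env` parameter is an assumption made here so the logic can be exercised without touching `os.environ`; the module in the diff reads `os.environ` directly.

```python
from collections.abc import Mapping

# Variable names mirror the removed module; the `env` parameter is illustrative.
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
PRODUCTION_ENV_VARS = ("DEER_FLOW_ENV", "ENVIRONMENT")
PRODUCTION_ENV_VALUES = frozenset({"prod", "production"})


def is_explicit_production(env: Mapping[str, str]) -> bool:
    """True when any production marker variable is set to prod/production."""
    return any(env.get(name, "").strip().lower() in PRODUCTION_ENV_VALUES for name in PRODUCTION_ENV_VARS)


def is_auth_disabled(env: Mapping[str, str]) -> bool:
    """The bypass needs an exact "1" and is vetoed by a production marker."""
    return env.get(AUTH_DISABLED_ENV_VAR) == "1" and not is_explicit_production(env)
```

Note that values such as `"true"` do not enable the bypass, and the production check tolerates surrounding whitespace and mixed case.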
@@ -17,13 +17,6 @@ from starlette.responses import JSONResponse
|
|||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||||
from app.gateway.auth_disabled import (
|
|
||||||
AUTH_SOURCE_AUTH_DISABLED,
|
|
||||||
AUTH_SOURCE_INTERNAL,
|
|
||||||
AUTH_SOURCE_SESSION,
|
|
||||||
get_auth_disabled_user,
|
|
||||||
is_auth_disabled,
|
|
||||||
)
|
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
||||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
@@ -87,14 +80,18 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
||||||
internal_user = get_internal_user()
|
internal_user = get_internal_user()
|
||||||
|
|
||||||
auth_source = AUTH_SOURCE_SESSION
|
|
||||||
access_token = request.cookies.get("access_token")
|
|
||||||
|
|
||||||
# Non-public path: require session cookie
|
# Non-public path: require session cookie
|
||||||
if internal_user is not None:
|
if internal_user is None and not request.cookies.get("access_token"):
|
||||||
user = internal_user
|
return JSONResponse(
|
||||||
auth_source = AUTH_SOURCE_INTERNAL
|
status_code=401,
|
||||||
elif access_token:
|
content={
|
||||||
|
"detail": AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.NOT_AUTHENTICATED,
|
||||||
|
message="Authentication required",
|
||||||
|
).model_dump()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Strict JWT validation: reject junk/expired tokens with 401
|
# Strict JWT validation: reject junk/expired tokens with 401
|
||||||
# right here instead of silently passing through. This closes
|
# right here instead of silently passing through. This closes
|
||||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||||
@@ -108,33 +105,19 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
# bubble up, so we catch and render it as JSONResponse here.
|
# bubble up, so we catch and render it as JSONResponse here.
|
||||||
from app.gateway.deps import get_current_user_from_request
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
if internal_user is not None:
|
||||||
|
user = internal_user
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
user = await get_current_user_from_request(request)
|
user = await get_current_user_from_request(request)
|
||||||
except HTTPException as exc:
|
except HTTPException as exc:
|
||||||
if not is_auth_disabled():
|
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
user = get_auth_disabled_user()
|
|
||||||
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
|
||||||
elif is_auth_disabled():
|
|
||||||
user = get_auth_disabled_user()
|
|
||||||
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
|
||||||
else:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=401,
|
|
||||||
content={
|
|
||||||
"detail": AuthErrorResponse(
|
|
||||||
code=AuthErrorCode.NOT_AUTHENTICATED,
|
|
||||||
message="Authentication required",
|
|
||||||
).model_dump()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stamp both request.state.user (for the contextvar pattern)
|
# Stamp both request.state.user (for the contextvar pattern)
|
||||||
# and request.state.auth (so @require_permission's "auth is
|
# and request.state.auth (so @require_permission's "auth is
|
||||||
# None" branch short-circuits instead of running the entire
|
# None" branch short-circuits instead of running the entire
|
||||||
# JWT-decode + DB-lookup pipeline a second time per request).
|
# JWT-decode + DB-lookup pipeline a second time per request).
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
request.state.auth_source = auth_source
|
|
||||||
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||||
token = set_current_user(user)
|
token = set_current_user(user)
|
||||||
try:
|
try:
|
||||||
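The left-hand side of the middleware hunks above picks the request identity in a fixed order: internal token first, then a valid session cookie, then the auth-disabled fallback, and otherwise a 401. A simplified sketch of that ordering follows; `resolve_auth` and its parameters are hypothetical, and the real middleware additionally renders the `HTTPException` from cookie validation as a JSON response.

```python
from typing import Any, Optional, Tuple


def resolve_auth(
    *,
    internal_user: Optional[Any],
    cookie_user: Optional[Any],
    auth_disabled: bool,
) -> Tuple[Optional[Any], Optional[str]]:
    """Return (user, auth_source), or (None, None) when the caller should answer 401.

    `cookie_user` is None both when no cookie was sent and when validation failed.
    """
    if internal_user is not None:
        return internal_user, "internal"
    if cookie_user is not None:
        return cookie_user, "session"
    if auth_disabled:
        # Covers a missing cookie as well as a cookie that failed validation.
        return "synthetic-admin", "auth_disabled"
    return None, None
```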
|
|||||||
@@ -276,8 +276,6 @@ def require_permission(
|
|||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* user_id triggers 404.
|
# row with a *different* user_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
|
||||||
|
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
@@ -290,22 +288,6 @@ def require_permission(
|
|||||||
str(auth.user.id),
|
str(auth.user.id),
|
||||||
require_existing=require_existing,
|
require_existing=require_existing,
|
||||||
)
|
)
|
||||||
if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
|
||||||
# Trusted internal callers (channel workers) also act for
|
|
||||||
# the connection owner carried in X-DeerFlow-Owner-User-Id.
|
|
||||||
# Scope the check to that owner instead of bypassing it; a
|
|
||||||
# leaked internal token must not grant cross-user thread
|
|
||||||
# access. The header is honored only after ``auth`` proved
|
|
||||||
# the caller holds the internal token (mirrors
|
|
||||||
# get_trusted_internal_owner_user_id, which keys off the
|
|
||||||
# middleware-stamped ``request.state.user``).
|
|
||||||
header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip()
|
|
||||||
if header_owner:
|
|
||||||
allowed = await thread_store.check_access(
|
|
||||||
thread_id,
|
|
||||||
header_owner,
|
|
||||||
require_existing=require_existing,
|
|
||||||
)
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=404,
|
status_code=404,
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from app.gateway.auth_disabled import is_auth_disabled
|
|
||||||
|
|
||||||
CSRF_COOKIE_NAME = "csrf_token"
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||||
@@ -40,9 +38,6 @@ def should_check_csrf(request: Request) -> bool:
|
|||||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if is_auth_disabled():
|
|
||||||
return False
|
|
||||||
|
|
||||||
path = request.url.path.rstrip("/")
|
path = request.url.path.rstrip("/")
|
||||||
# Exempt /api/v1/auth/me endpoint
|
# Exempt /api/v1/auth/me endpoint
|
||||||
if path == "/api/v1/auth/me":
|
if path == "/api/v1/auth/me":
|
||||||
|
|||||||
@@ -331,17 +331,6 @@ async def get_current_user_from_request(request: Request):
|
|||||||
|
|
||||||
Raises HTTPException 401 if not authenticated.
|
Raises HTTPException 401 if not authenticated.
|
||||||
"""
|
"""
|
||||||
state = getattr(request, "state", None)
|
|
||||||
state_user = getattr(state, "user", None)
|
|
||||||
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
|
|
||||||
|
|
||||||
if state_user is not None and getattr(state, "auth_source", None) in {
|
|
||||||
AUTH_SOURCE_SESSION,
|
|
||||||
AUTH_SOURCE_AUTH_DISABLED,
|
|
||||||
AUTH_SOURCE_INTERNAL,
|
|
||||||
}:
|
|
||||||
return state_user
|
|
||||||
|
|
||||||
from app.gateway.auth import decode_token
|
from app.gateway.auth import decode_token
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,10 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||||
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
|
||||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||||
INTERNAL_SYSTEM_ROLE = "internal"
|
INTERNAL_SYSTEM_ROLE = "internal"
|
||||||
|
|
||||||
@@ -25,12 +23,9 @@ def _load_internal_auth_token() -> str:
|
|||||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||||
|
|
||||||
|
|
||||||
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
def create_internal_auth_headers() -> dict[str, str]:
|
||||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||||
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||||
if owner_user_id:
|
|
||||||
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
|
||||||
return headers
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||||
@@ -41,21 +36,3 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
|||||||
def get_internal_user():
|
def get_internal_user():
|
||||||
"""Return the synthetic user used for trusted internal channel calls."""
|
"""Return the synthetic user used for trusted internal channel calls."""
|
||||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||||
-
-
-def get_trusted_internal_owner_user_id(request: Any) -> str | None:
-    """Return the owner override for a trusted internal request, if present.
-
-    The header is ignored for normal browser/API callers. It is only honored
-    after ``AuthMiddleware`` has validated the internal auth token and stamped
-    the synthetic internal user onto ``request.state.user``.
-    """
-    user = getattr(getattr(request, "state", None), "user", None)
-    if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
-        return None
-
-    owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
-    if not owner_user_id:
-        return None
-    owner_user_id = owner_user_id.strip()
-    return owner_user_id or None
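The removed helper honors the owner header only when the middleware has already stamped an internal user on the request. The sketch below reproduces that check against a fake request object. `fake_request` is a hypothetical test double, and its plain `dict` headers are case-sensitive, unlike Starlette's `Headers`.

```python
from types import SimpleNamespace
from typing import Any, Optional

INTERNAL_SYSTEM_ROLE = "internal"
OWNER_HEADER = "X-DeerFlow-Owner-User-Id"


def trusted_owner_user_id(request: Any) -> Optional[str]:
    """Return the owner override only for requests stamped with the internal role."""
    user = getattr(getattr(request, "state", None), "user", None)
    if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
        return None
    value = (request.headers.get(OWNER_HEADER) or "").strip()
    return value or None


def fake_request(role: str, headers: dict) -> SimpleNamespace:
    """Hypothetical stand-in for a request with middleware-stamped state."""
    return SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(system_role=role)), headers=headers)
```

A caller with any other role gets `None` even when it sends the header, which is the property the docstring in the diff describes.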
|
|||||||
@@ -20,7 +20,6 @@ from langgraph_sdk import Auth
|
|||||||
|
|
||||||
from app.gateway.auth.errors import TokenError
|
from app.gateway.auth.errors import TokenError
|
||||||
from app.gateway.auth.jwt import decode_token
|
from app.gateway.auth.jwt import decode_token
|
||||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
|
||||||
from app.gateway.deps import get_local_provider
|
from app.gateway.deps import get_local_provider
|
||||||
|
|
||||||
auth = Auth()
|
auth = Auth()
|
||||||
@@ -39,9 +38,6 @@ def _check_csrf(request) -> None:
|
|||||||
if method.upper() not in _CSRF_METHODS:
|
if method.upper() not in _CSRF_METHODS:
|
||||||
return
|
return
|
||||||
|
|
||||||
if is_auth_disabled():
|
|
||||||
return
|
|
||||||
|
|
||||||
cookie_token = request.cookies.get("csrf_token")
|
cookie_token = request.cookies.get("csrf_token")
|
||||||
header_token = request.headers.get("x-csrf-token")
|
header_token = request.headers.get("x-csrf-token")
|
||||||
|
|
||||||
@@ -70,9 +66,6 @@ async def authenticate(request):
|
|||||||
# are rejected early, even if the cookie carries a valid JWT.
|
# are rejected early, even if the cookie carries a valid JWT.
|
||||||
_check_csrf(request)
|
_check_csrf(request)
|
||||||
|
|
||||||
if is_auth_disabled():
|
|
||||||
return AUTH_DISABLED_USER_ID
|
|
||||||
|
|
||||||
token = request.cookies.get("access_token")
|
token = request.cookies.get("access_token")
|
||||||
if not token:
|
if not token:
|
||||||
raise Auth.exceptions.HTTPException(
|
raise Auth.exceptions.HTTPException(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""CRUD API for custom agents."""
|
"""CRUD API for custom agents."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -214,21 +213,15 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
user_id = get_effective_user_id()
|
user_id = get_effective_user_id()
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
|
|
||||||
def _create_agent() -> AgentResponse | None:
|
|
||||||
# Worker thread: base-dir resolution, existence checks, directory/file
|
|
||||||
# creation, read-back, and failure cleanup are all blocking filesystem
|
|
||||||
# IO that must stay off the event loop.
|
|
||||||
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
agent_dir = paths.user_agent_dir(user_id, normalized_name)
|
||||||
legacy_dir = paths.agent_dir(normalized_name)
|
legacy_dir = paths.agent_dir(normalized_name)
|
||||||
|
|
||||||
if legacy_dir.exists():
|
if agent_dir.exists() or legacy_dir.exists():
|
||||||
return None # signals 409 to the caller
|
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||||
agent_dir.mkdir(parents=True, exist_ok=False)
|
|
||||||
except FileExistsError:
|
|
||||||
return None # signals 409 to the caller
|
|
||||||
# Write config.yaml
|
# Write config.yaml
|
||||||
config_data: dict = {"name": normalized_name}
|
config_data: dict = {"name": normalized_name}
|
||||||
if request.description:
|
if request.description:
|
||||||
@@ -252,23 +245,16 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
|
|||||||
|
|
||||||
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
agent_cfg = load_agent_config(normalized_name, user_id=user_id)
|
||||||
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
return _agent_config_to_response(agent_cfg, include_soul=True, user_id=user_id)
|
||||||
except Exception:
|
|
||||||
# Clean up partial state on failure before surfacing the error.
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up on failure
|
||||||
if agent_dir.exists():
|
if agent_dir.exists():
|
||||||
shutil.rmtree(agent_dir)
|
shutil.rmtree(agent_dir)
|
||||||
raise
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await asyncio.to_thread(_create_agent)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
|
logger.error(f"Failed to create agent '{request.name}': {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to create agent: {str(e)}")
|
||||||
|
|
||||||
if response is None:
|
|
||||||
raise HTTPException(status_code=409, detail=f"Agent '{normalized_name}' already exists")
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/agents/{name}",
|
"/agents/{name}",
|
||||||
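The two sides of the create hunks above differ in how they detect an existing agent: one checks `exists()` and then calls `mkdir(exist_ok=True)`, the other lets `mkdir(exist_ok=False)` raise `FileExistsError`. The second form closes the window between check and creation. A minimal sketch, assuming hypothetical helper names and status codes chosen only for illustration:

```python
import asyncio
import tempfile
from pathlib import Path


def create_dir_once(path: Path) -> bool:
    """Return True if this call created the directory, False if it already existed."""
    try:
        path.mkdir(parents=True, exist_ok=False)
    except FileExistsError:
        return False
    return True


async def create_agent_dir(path: Path) -> int:
    """Run the blocking mkdir in a worker thread and map the outcome to a status code."""
    created = await asyncio.to_thread(create_dir_once, path)
    return 201 if created else 409  # illustrative codes, not taken from the endpoint
```

Returning a plain value from the worker and raising the HTTP error on the event-loop side keeps framework exceptions out of the thread, which matches the `return None  # signals 409` convention in the diff.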
@@ -442,30 +428,19 @@ async def delete_agent(name: str) -> None:
|
|||||||
name = _normalize_agent_name(name)
|
name = _normalize_agent_name(name)
|
||||||
user_id = get_effective_user_id()
|
user_id = get_effective_user_id()
|
||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
|
|
||||||
def _remove_agent_dir() -> tuple[str, str]:
|
|
||||||
# Runs in a worker thread: resolving the base dir, probing the directory
|
|
||||||
# (`exists`), and removing it (`rmtree`) are all blocking filesystem IO
|
|
||||||
# that must stay off the event loop.
|
|
||||||
agent_dir = paths.user_agent_dir(user_id, name)
|
agent_dir = paths.user_agent_dir(user_id, name)
|
||||||
|
|
||||||
if not agent_dir.exists():
|
if not agent_dir.exists():
|
||||||
outcome = "legacy" if paths.agent_dir(name).exists() else "missing"
|
if paths.agent_dir(name).exists():
|
||||||
return outcome, str(agent_dir)
|
|
||||||
shutil.rmtree(agent_dir)
|
|
||||||
return "deleted", str(agent_dir)
|
|
||||||
|
|
||||||
try:
|
|
||||||
outcome, agent_dir = await asyncio.to_thread(_remove_agent_dir)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
|
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
|
|
||||||
|
|
||||||
if outcome == "legacy":
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
detail=(f"Agent '{name}' only exists in the legacy shared layout and is not scoped to a user. Run scripts/migrate_user_isolation.py to move legacy agents into the per-user layout before deleting."),
|
||||||
)
|
)
|
||||||
if outcome == "missing":
|
|
||||||
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
raise HTTPException(status_code=404, detail=f"Agent '{name}' not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.rmtree(agent_dir)
|
||||||
logger.info(f"Deleted agent '{name}' from {agent_dir}")
|
logger.info(f"Deleted agent '{name}' from {agent_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete agent '{name}': {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to delete agent: {str(e)}")
|
||||||
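In the delete hunk above, the left-hand side performs the directory probe and `rmtree` inside a worker thread and returns an outcome string that the coroutine translates into a response. A self-contained sketch of that shape follows; the function names and the 204 status are assumptions for illustration.

```python
import asyncio
import shutil
import tempfile
from pathlib import Path


def remove_dir(user_dir: Path, legacy_dir: Path) -> str:
    """Blocking helper: returns 'deleted', 'legacy', or 'missing'."""
    if not user_dir.exists():
        return "legacy" if legacy_dir.exists() else "missing"
    shutil.rmtree(user_dir)
    return "deleted"


async def delete_agent_dir(user_dir: Path, legacy_dir: Path) -> int:
    """Keep filesystem IO off the event loop, then map the outcome to a status code."""
    outcome = await asyncio.to_thread(remove_dir, user_dir, legacy_dir)
    return {"deleted": 204, "legacy": 409, "missing": 404}[outcome]
```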
|
|||||||
@@ -341,19 +341,9 @@ async def change_password(request: Request, response: Response, body: ChangePass
|
|||||||
- Re-issues session cookie with new token_version
|
- Re-issues session cookie with new token_version
|
||||||
"""
|
"""
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||||
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
|
|
||||||
|
|
||||||
user = await get_current_user_from_request(request)
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=AuthErrorResponse(
|
|
||||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
|
||||||
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
|
|
||||||
).model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if user.password_hash is None:
|
if user.password_hash is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||||
|
|
||||||
|
|||||||
@@ -1,670 +0,0 @@
-"""Browser-facing APIs for user-owned IM channel bindings."""
-
-from __future__ import annotations
-
-import asyncio
-import logging
-import secrets
-from datetime import UTC, datetime, timedelta
-from typing import Any
-
-from fastapi import APIRouter, HTTPException, Request, Response
-from pydantic import BaseModel, Field
-
-from app.channels.runtime_config_store import (
-    ChannelRuntimeConfigStore,
-    apply_runtime_connection_config,
-    merge_runtime_channel_configs,
-)
-from deerflow.config.channel_connections_config import ChannelConnectionsConfig
-from deerflow.persistence.channel_connections import ChannelConnectionRepository
-from deerflow.persistence.engine import get_session_factory
-
-router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
-logger = logging.getLogger(__name__)
-
-_STATE_TTL_SECONDS = 600
-_MASKED_CREDENTIAL_VALUE = "********"
-
-
-class ChannelCredentialFieldResponse(BaseModel):
-    name: str
-    label: str
-    type: str = "text"
-    required: bool = True
-
-
-class ChannelProviderResponse(BaseModel):
-    provider: str
-    display_name: str
-    enabled: bool
-    configured: bool
-    connectable: bool
-    unavailable_reason: str | None = None
-    auth_mode: str
-    connection_status: str
-    credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
-    credential_values: dict[str, str] = Field(default_factory=dict)
-
-
-class ChannelProvidersResponse(BaseModel):
-    enabled: bool
-    providers: list[ChannelProviderResponse]
-
-
-class ChannelConnectionResponse(BaseModel):
-    id: str
-    provider: str
-    status: str
-    external_account_id: str | None = None
-    external_account_name: str | None = None
-    workspace_id: str | None = None
-    workspace_name: str | None = None
-    scopes: list[str] = Field(default_factory=list)
-    metadata: dict[str, Any] = Field(default_factory=dict)
-
-
-class ChannelConnectionsResponse(BaseModel):
-    connections: list[ChannelConnectionResponse]
-
-
-class ChannelConnectResponse(BaseModel):
-    provider: str
-    mode: str
-    url: str | None = None
-    code: str
-    instruction: str
-    expires_in: int
-
-
-class ChannelRuntimeConfigRequest(BaseModel):
-    values: dict[str, str] = Field(default_factory=dict)
-
-
-_PROVIDER_META: dict[str, dict[str, str]] = {
-    "telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
-    "slack": {"display_name": "Slack", "auth_mode": "binding_code"},
-    "discord": {"display_name": "Discord", "auth_mode": "binding_code"},
-    "feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
-    "dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
-    "wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
-    "wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
-}
-
-_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
-    "telegram": (
-        {"name": "bot_token", "label": "Bot token", "type": "password"},
-        {"name": "bot_username", "label": "Bot username", "type": "text"},
-    ),
-    "slack": (
-        {"name": "bot_token", "label": "Bot token", "type": "password"},
-        {"name": "app_token", "label": "App token", "type": "password"},
-    ),
-    "discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
-    "feishu": (
-        {"name": "app_id", "label": "App ID", "type": "text"},
-        {"name": "app_secret", "label": "App secret", "type": "password"},
-    ),
-    "dingtalk": (
-        {"name": "client_id", "label": "Client ID", "type": "text"},
-        {"name": "client_secret", "label": "Client secret", "type": "password"},
-    ),
-    "wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
-    "wecom": (
-        {"name": "bot_id", "label": "Bot ID", "type": "text"},
-        {"name": "bot_secret", "label": "Bot secret", "type": "password"},
-    ),
-}
-
-_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
-    "telegram": ("bot_token",),
-    "slack": ("bot_token", "app_token"),
-    "discord": ("bot_token",),
-    "feishu": ("app_id", "app_secret"),
-    "dingtalk": ("client_id", "client_secret"),
-    "wechat": ("bot_token",),
-    "wecom": ("bot_id", "bot_secret"),
-}
-
-
-def _get_user_id(request: Request) -> str:
-    user = getattr(request.state, "user", None)
-    if user is None:
-        raise HTTPException(status_code=401, detail="Authentication required")
-    return str(user.id)
-
-
-async def _require_admin_user(request: Request) -> None:
-    """Require an admin caller for instance-wide channel runtime mutations.
-
-    Runtime credentials and the channel workers they start/stop are shared by
-    every user of the deployment, so only admins may change them (same model
-    as the MCP config API). Auth-disabled local mode uses a synthetic admin
-    user and is unaffected.
-    """
-    user = getattr(request.state, "user", None)
-    if user is None:
-        from app.gateway.deps import get_current_user_from_request
-
-        user = await get_current_user_from_request(request)
-
-    if getattr(user, "system_role", None) != "admin":
-        raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.")
-
-
-def _get_app_config():
-    from deerflow.config.app_config import get_app_config
-
-    return get_app_config()
-
-
-async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
-    store = getattr(request.app.state, "channel_runtime_config_store", None)
-    if isinstance(store, ChannelRuntimeConfigStore):
-        return store
-    # Constructing the store reads its JSON file from disk; keep it off the
-    # event loop.
-    store = await asyncio.to_thread(ChannelRuntimeConfigStore)
-    request.app.state.channel_runtime_config_store = store
-    return store
-
-
-async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
-    config = getattr(request.app.state, "channel_connections_config", None)
-    if not isinstance(config, ChannelConnectionsConfig):
-        config = _get_app_config().channel_connections
-    config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request))
-    request.app.state.channel_connections_config = config
-    return config
-
-
-async def _get_channels_config(request: Request) -> dict[str, Any]:
-    state_config = getattr(request.app.state, "channels_config", None)
-    if isinstance(state_config, dict):
-        return state_config
-
-    result = await _load_channels_config(request, await _get_channel_connections_config(request))
-    request.app.state.channels_config = result
-    return result
-
-
-async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
-    app_config = _get_app_config()
-    extra = app_config.model_extra or {}
-    channels_config = extra.get("channels")
-    result = dict(channels_config) if isinstance(channels_config, dict) else {}
-    merge_runtime_channel_configs(
-        result,
-        config,
-        store=await _get_runtime_config_store(request),
-    )
-    return result
-
-
-def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
-    repo = getattr(request.app.state, "channel_connection_repo", None)
-    if isinstance(repo, ChannelConnectionRepository):
-        return repo
-
-    sf = get_session_factory()
-    if sf is None:
-        raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
-
-    repo = ChannelConnectionRepository(sf)
-    request.app.state.channel_connection_repo = repo
-    return repo
-
-
-def _provider_config(config: ChannelConnectionsConfig, provider: str):
-    provider_config = getattr(config, provider, None)
-    if provider_config is None:
-        raise HTTPException(status_code=404, detail="Unknown channel provider")
-    return provider_config
-
-
-def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
-    runtime_config = channels_config.get(provider)
-    if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
-        return False
-    return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
-
-
-def _runtime_unavailable_reason(provider: str) -> str:
-    meta = _PROVIDER_META.get(provider)
-    display_name = meta["display_name"] if meta else provider
-    return f"Enter the required {display_name} credentials to connect this channel."
-
-
-def _runtime_not_running_reason(provider: str) -> str:
-    meta = _PROVIDER_META.get(provider)
-    display_name = meta["display_name"] if meta else provider
-    return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
-
-
-def _runtime_channel_running(provider: str) -> bool | None:
-    try:
-        from app.channels.service import get_channel_service
-    except Exception:
-        logger.debug("Unable to inspect channel service status", exc_info=True)
-        return None
-
-    service = get_channel_service()
-    if service is None:
-        return None
-    try:
-        status = service.get_status()
-    except Exception:
-        logger.debug("Unable to read channel service status", exc_info=True)
-        return None
-
-    if not status.get("service_running"):
-        return False
-    channel_status = status.get("channels", {}).get(provider)
-    if not isinstance(channel_status, dict):
-        return None
-    return bool(channel_status.get("running"))
-
-
-async def _ensure_runtime_channel_ready_if_available(
-    provider: str,
-    channels_config: dict[str, Any],
-) -> bool | None:
-    runtime_config = channels_config.get(provider)
-    if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
-        return None
-
-    try:
-        from app.channels.service import get_channel_service
-    except Exception:
-        logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
-        return None
-
-    service = get_channel_service()
-    if service is None:
-        return None
-
-    ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
-    if ensure_channel_ready is None:
-        return None
-
-    try:
-        return await ensure_channel_ready(provider, runtime_config)
-    except Exception:
-        logger.exception("Failed to reconcile runtime channel readiness")
-        return False
-
-
-def _provider_unavailable_reason(
-    config: ChannelConnectionsConfig,
-    channels_config: dict[str, Any],
-    provider: str,
-) -> str | None:
-    provider_config = _provider_config(config, provider)
-    if not provider_config.enabled:
-        return None
-    if not provider_config.configured:
-        return _runtime_unavailable_reason(provider)
-    if not _runtime_channel_configured(provider, channels_config):
-        return _runtime_unavailable_reason(provider)
-    if _runtime_channel_running(provider) is False:
-        return _runtime_not_running_reason(provider)
-    return None
-
-
-def _provider_status(
-    config: ChannelConnectionsConfig,
-    channels_config: dict[str, Any],
-    provider: str,
-) -> tuple[dict[str, bool], str | None]:
-    declared = config.provider_status(provider)
-    unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
-    configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
-    return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
-
-
-def _new_binding_code() -> str:
-    return secrets.token_urlsafe(16)
-
-
-async def _create_state(
-    repo: ChannelConnectionRepository,
-    *,
-    owner_user_id: str,
-    provider: str,
-) -> str:
-    state = _new_binding_code()
-    await repo.create_oauth_state(
-        owner_user_id=owner_user_id,
-        provider=provider,
-        state=state,
-        expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
-    )
-    return state
-
-
-def _connect_instruction(provider: str, code: str) -> str:
-    if provider == "telegram":
-        return f"Send /start {code} to the DeerFlow Telegram bot."
-    meta = _PROVIDER_META.get(provider)
-    if meta is None:
-        raise HTTPException(status_code=404, detail="Unknown channel provider")
-    return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
-
-
-def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
-    if provider == "telegram":
-        provider_config = _provider_config(config, provider)
-        return f"https://t.me/{provider_config.bot_username}?start={code}"
-    if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
-        return None
-    raise HTTPException(status_code=404, detail="Unknown channel provider")
-
-
-def _connection_updated_at(connection: dict[str, Any]) -> datetime:
-    value = connection.get("updated_at")
-    if isinstance(value, datetime):
-        return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
-    if isinstance(value, str) and value:
-        try:
-            return datetime.fromisoformat(value.replace("Z", "+00:00"))
-        except ValueError:
-            pass
-    return datetime.min.replace(tzinfo=UTC)
-
-
-def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
-    by_provider: dict[str, dict[str, Any]] = {}
-    for item in connections:
-        existing = by_provider.get(item["provider"])
-        if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
-            by_provider[item["provider"]] = item
-    return by_provider
-
-
-def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
-    fields = _CREDENTIAL_FIELDS.get(provider)
-    if fields is None:
-        raise HTTPException(status_code=404, detail="Unknown channel provider")
-    return [ChannelCredentialFieldResponse(**field) for field in fields]
-
-
-def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
-    runtime_config = channels_config.get(provider)
-    if not isinstance(runtime_config, dict):
-        return {}
-
-    values: dict[str, str] = {}
-    for field in _credential_fields(provider):
-        value = str(runtime_config.get(field.name) or "").strip()
-        if not value:
-            continue
-        values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
-    return values
-
-
-def _provider_response(
-    config: ChannelConnectionsConfig,
-    channels_config: dict[str, Any],
-    provider: str,
-    meta: dict[str, str],
-    connection: dict[str, Any] | None = None,
-) -> ChannelProviderResponse:
-    from app.gateway.auth_disabled import is_auth_disabled
-
-    status, unavailable_reason = _provider_status(config, channels_config, provider)
-    if connection:
-        connection_status = connection["status"]
-    elif is_auth_disabled() and status["configured"] and unavailable_reason is None:
-        # Auth-disabled local mode routes every channel message to the default
-        # user, so a configured running channel needs no per-user binding.
-        connection_status = "connected"
-    else:
-        connection_status = "not_connected"
-    credential_values = _credential_values(provider, channels_config)
-    if provider == "telegram" and not credential_values.get("bot_username"):
-        bot_username = str(_provider_config(config, provider).bot_username or "").strip()
-        if bot_username:
-            credential_values["bot_username"] = bot_username
-    return ChannelProviderResponse(
-        provider=provider,
-        display_name=meta["display_name"],
-        enabled=status["enabled"],
-        configured=status["configured"],
-        connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
-        unavailable_reason=unavailable_reason,
-        auth_mode=meta["auth_mode"],
-        connection_status=connection_status,
-        credential_fields=_credential_fields(provider),
-        credential_values=credential_values,
-    )
-
-
-def _required_runtime_values(
-    provider: str,
-    values: dict[str, str],
-    existing_config: dict[str, Any] | None = None,
-) -> dict[str, str]:
-    fields = _credential_fields(provider)
-    cleaned: dict[str, str] = {}
-    missing: list[str] = []
-    existing_config = existing_config or {}
-    for field in fields:
-        raw_value = values.get(field.name, "")
-        if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
-            existing_value = str(existing_config.get(field.name) or "").strip()
-            if existing_value:
-                cleaned[field.name] = existing_value
-                continue
-        value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
-        if field.required and not value:
-            missing.append(field.label)
-        cleaned[field.name] = value
-    if missing:
-        raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
-    return cleaned
-
-
-async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
-    try:
-        from app.channels.service import get_channel_service
-    except Exception:
-        logger.exception("Failed to import channel service while configuring a runtime channel")
-        return None
-
-    service = get_channel_service()
-    if service is None:
-        return None
-    return await service.configure_channel(provider, runtime_config)
-
-
-async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
-    try:
-        from app.channels.service import get_channel_service
-    except Exception:
-        logger.exception("Failed to import channel service while disconnecting a runtime channel")
-        return None
-
-    service = get_channel_service()
-    if service is None:
-        return None
-
-    runtime_config = channels_config.get(provider)
-    if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
-        return await service.configure_channel(provider, runtime_config)
-    return await service.remove_channel(provider)
-
-
-@router.get("/providers", response_model=ChannelProvidersResponse)
-async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
-    config = await _get_channel_connections_config(request)
-    channels_config = await _get_channels_config(request)
-    repo = None
-    if config.enabled:
-        try:
-            repo = _get_repository(request, config)
-        except HTTPException as exc:
-            if exc.status_code != 503:
-                raise
-    owner_user_id = _get_user_id(request)
-    connections = await repo.list_connections(owner_user_id) if repo is not None else []
-    by_provider = _newest_connection_by_provider(connections)
-
-    enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]]
-    # Readiness reconciliation is independent per provider; run it
-    # concurrently so one slow channel restart does not serialize the
-    # whole /providers response.
-    await asyncio.gather(
-        *(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)),
-    )
-
-    providers: list[ChannelProviderResponse] = []
-    for provider in enabled_providers:
-        connection = by_provider.get(provider)
-        providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection))
-    return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
-
-
-@router.get("/connections", response_model=ChannelConnectionsResponse)
-async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
-    config = await _get_channel_connections_config(request)
-    if not config.enabled:
-        return ChannelConnectionsResponse(connections=[])
-    repo = _get_repository(request, config)
-    rows = await repo.list_connections(_get_user_id(request))
-    return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
-
-
-@router.delete("/connections/{connection_id}", status_code=204)
-async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
-    config = await _get_channel_connections_config(request)
-    if not config.enabled:
-        raise HTTPException(status_code=400, detail="Channel connections are disabled")
-
-    repo = _get_repository(request, config)
-    disconnected = await repo.disconnect_connection(
-        connection_id=connection_id,
-        owner_user_id=_get_user_id(request),
-    )
-    if not disconnected:
-        raise HTTPException(status_code=404, detail="Channel connection not found")
-    return Response(status_code=204)
-
-
-@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
-async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
-    await _require_admin_user(request)
-    config = await _get_channel_connections_config(request)
-    if not config.enabled:
-        raise HTTPException(status_code=400, detail="Channel connections are disabled")
-
-    provider_config = _provider_config(config, provider)
-    if not provider_config.enabled:
-        raise HTTPException(status_code=400, detail="Channel provider is not enabled")
-
-    owner_user_id = _get_user_id(request)
-    try:
-        repo = _get_repository(request, config)
-    except HTTPException as exc:
-        if exc.status_code != 503:
-            raise
-        repo = None
-
-    if repo is not None:
-        for connection in await repo.list_connections(owner_user_id):
-            if connection["provider"] == provider and connection["status"] != "revoked":
-                await repo.disconnect_connection(
-                    connection_id=connection["id"],
-                    owner_user_id=owner_user_id,
-                )
-
-    store = await _get_runtime_config_store(request)
-    await asyncio.to_thread(store.set_provider_disconnected, provider)
-    channels_config = await _load_channels_config(request, config)
-    request.app.state.channels_config = channels_config
-
-    stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
-    if stopped is False:
-        display_name = _PROVIDER_META[provider]["display_name"]
-        raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
-
-    return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
-
-
-@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
-async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
-    config = await _get_channel_connections_config(request)
-    channels_config = await _get_channels_config(request)
-    if not config.enabled:
-        raise HTTPException(status_code=400, detail="Channel connections are disabled")
-
-    provider_config = _provider_config(config, provider)
-    if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
-        await _ensure_runtime_channel_ready_if_available(provider, channels_config)
-
-    status, unavailable_reason = _provider_status(config, channels_config, provider)
-    if not status["enabled"]:
-        raise HTTPException(status_code=400, detail="Channel provider is not enabled")
-    if unavailable_reason:
-        raise HTTPException(status_code=400, detail=unavailable_reason)
-    if not status["configured"]:
-        raise HTTPException(status_code=400, detail="Channel provider is not configured")
-
-    repo = _get_repository(request, config)
-    code = await _create_state(
-        repo,
-        owner_user_id=_get_user_id(request),
-        provider=provider,
-    )
-    return ChannelConnectResponse(
-        provider=provider,
-        mode=_PROVIDER_META[provider]["auth_mode"],
-        url=_connect_url(config, provider, code),
-        code=code,
-        instruction=_connect_instruction(provider, code),
-        expires_in=_STATE_TTL_SECONDS,
-    )
-
-
-@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
-async def configure_channel_provider_runtime(
-    provider: str,
-    body: ChannelRuntimeConfigRequest,
-    request: Request,
-) -> ChannelProviderResponse:
-    await _require_admin_user(request)
-    config = await _get_channel_connections_config(request)
-    if not config.enabled:
-        raise HTTPException(status_code=400, detail="Channel connections are disabled")
-
-    provider_config = _provider_config(config, provider)
-    if not provider_config.enabled:
-        raise HTTPException(status_code=400, detail="Channel provider is not enabled")
-
-    channels_config = await _get_channels_config(request)
-    existing = channels_config.get(provider)
-    runtime_config = dict(existing) if isinstance(existing, dict) else {}
-    values = _required_runtime_values(provider, body.values, runtime_config)
-    runtime_config["enabled"] = True
-
-    for key in _RUNTIME_REQUIREMENTS[provider]:
-        runtime_config[key] = values[key]
-
-    if provider == "telegram":
-        # The deep-link username is persisted with the runtime channel config
-        # (set_provider_config below) and applied to future requests via
-        # apply_runtime_connection_config; never mutate the config instance
-        # cached by get_app_config().
-        runtime_config["bot_username"] = values["bot_username"]
-
-    channels_config[provider] = runtime_config
-    request.app.state.channels_config = channels_config
-
-    started = await _restart_runtime_channel_if_available(provider, runtime_config)
-    if started is False:
-        display_name = _PROVIDER_META[provider]["display_name"]
-        raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
-
-    store = await _get_runtime_config_store(request)
-    await asyncio.to_thread(store.set_provider_config, provider, runtime_config)
-
-    return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
@@ -1,10 +1,9 @@
 import json
 import logging
-import os
 from pathlib import Path
 from typing import Literal
 
-from fastapi import APIRouter, HTTPException, Request, status
+from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel, Field
 
 from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
@@ -13,11 +12,6 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/api", tags=["mcp"])
 
 
-_MCP_STDIO_COMMAND_ALLOWLIST_ENV = "DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST"
-_DEFAULT_MCP_STDIO_COMMAND_ALLOWLIST = frozenset({"npx", "uvx"})
-_SHELL_METACHARS = frozenset(";|&`$<>\n\r")
-
-
 class McpOAuthConfigResponse(BaseModel):
     """OAuth configuration for an MCP server."""
 
@@ -72,78 +66,6 @@ class McpConfigUpdateRequest(BaseModel):
 _MASKED_VALUE = "***"
 
 
-async def _require_admin_user(request: Request) -> None:
-    """Require the authenticated caller to be an admin user.
-
-    ``AuthMiddleware`` normally stamps ``request.state.user`` before the
-    request reaches this router. Falling back to the strict dependency keeps
-    this route safe even in tests or alternative ASGI compositions that mount
-    the router without the global middleware.
-    """
-    user = getattr(request.state, "user", None)
-    if user is None:
-        from app.gateway.deps import get_current_user_from_request
-
-        user = await get_current_user_from_request(request)
-
-    if getattr(user, "system_role", None) != "admin":
-        raise HTTPException(
-            status_code=status.HTTP_403_FORBIDDEN,
-            detail="Admin privileges required to manage MCP configuration.",
-        )
-
-
-def _allowed_stdio_commands() -> set[str]:
-    """Return executable names allowed for API-managed stdio MCP servers."""
-    raw = os.environ.get(_MCP_STDIO_COMMAND_ALLOWLIST_ENV)
-    base = set(_DEFAULT_MCP_STDIO_COMMAND_ALLOWLIST)
-    if raw is None:
-        return base
-    extra = {item.strip() for item in raw.split(",") if item.strip()}
-    return base | extra
-
-
-def _stdio_command_name(command: str | None, *, server_name: str) -> str:
-    """Normalize and validate a stdio command field from the API boundary."""
-    if command is None or not command.strip():
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=f"MCP server '{server_name}' with stdio transport requires a command.",
-        )
-
-    stripped = command.strip()
-    has_path_separator = "/" in stripped or "\\" in stripped
-    if stripped != command or has_path_separator or any(ch.isspace() for ch in stripped) or any(ch in stripped for ch in _SHELL_METACHARS):
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=(f"MCP server '{server_name}' command must be a single executable name; put parameters in args instead."),
-        )
-
-    return stripped
-
-
-def _validate_mcp_update_request(request: McpConfigUpdateRequest) -> None:
-    """Validate API-submitted MCP config before it is persisted.
-
-    Local config files can still express arbitrary advanced setups, but the
-    HTTP API is an untrusted boundary. Restricting stdio commands here reduces
-    the blast radius of a compromised authenticated browser session.
-    """
-    allowed_commands = _allowed_stdio_commands()
-    for name, server in request.mcp_servers.items():
-        transport_type = (server.type or "stdio").lower()
-        if transport_type != "stdio":
-            continue
-
-        command_name = _stdio_command_name(server.command, server_name=name)
-        if command_name not in allowed_commands:
-            allowed = ", ".join(sorted(allowed_commands)) or "<none>"
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail=(f"MCP server '{name}' uses disallowed stdio command '{command_name}'. Allowed commands: {allowed}. Configure {_MCP_STDIO_COMMAND_ALLOWLIST_ENV} to extend this list."),
-            )
-
-
 def _mask_server_config(server: McpServerConfigResponse) -> McpServerConfigResponse:
     """Return a copy of server config with sensitive fields masked.
 
@@ -240,7 +162,7 @@ def _merge_preserving_secrets(
     summary="Get MCP Configuration",
     description="Retrieve the current Model Context Protocol (MCP) server configurations.",
 )
-async def get_mcp_configuration(request: Request) -> McpConfigResponse:
+async def get_mcp_configuration() -> McpConfigResponse:
     """Get the current MCP configuration.
 
     Returns:
@@ -261,8 +183,6 @@ async def get_mcp_configuration(request: Request) -> McpConfigResponse:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
await _require_admin_user(request)
|
|
||||||
|
|
||||||
config = get_extensions_config()
|
config = get_extensions_config()
|
||||||
|
|
||||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()}
|
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()}
|
||||||
@@ -275,7 +195,7 @@ async def get_mcp_configuration(request: Request) -> McpConfigResponse:
|
|||||||
summary="Update MCP Configuration",
|
summary="Update MCP Configuration",
|
||||||
description="Update Model Context Protocol (MCP) server configurations and save to file.",
|
description="Update Model Context Protocol (MCP) server configurations and save to file.",
|
||||||
)
|
)
|
||||||
async def update_mcp_configuration(request: Request, body: McpConfigUpdateRequest) -> McpConfigResponse:
|
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
|
||||||
"""Update the MCP configuration.
|
"""Update the MCP configuration.
|
||||||
|
|
||||||
This will:
|
This will:
|
||||||
@@ -308,9 +228,6 @@ async def update_mcp_configuration(request: Request, body: McpConfigUpdateReques
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await _require_admin_user(request)
|
|
||||||
_validate_mcp_update_request(body)
|
|
||||||
|
|
||||||
# Get the current config path (or determine where to save it)
|
# Get the current config path (or determine where to save it)
|
||||||
config_path = ExtensionsConfig.resolve_config_path()
|
config_path = ExtensionsConfig.resolve_config_path()
|
||||||
|
|
||||||
@@ -338,7 +255,7 @@ async def update_mcp_configuration(request: Request, body: McpConfigUpdateReques
|
|||||||
|
|
||||||
# Merge incoming server configs with raw on-disk secrets
|
# Merge incoming server configs with raw on-disk secrets
|
||||||
merged_servers: dict[str, McpServerConfigResponse] = {}
|
merged_servers: dict[str, McpServerConfigResponse] = {}
|
||||||
for name, incoming in body.mcp_servers.items():
|
for name, incoming in request.mcp_servers.items():
|
||||||
raw_server = raw_servers.get(name)
|
raw_server = raw_servers.get(name)
|
||||||
if raw_server is not None:
|
if raw_server is not None:
|
||||||
merged_servers[name] = _merge_preserving_secrets(
|
merged_servers[name] = _merge_preserving_secrets(
|
||||||
@@ -366,8 +283,6 @@ async def update_mcp_configuration(request: Request, body: McpConfigUpdateReques
|
|||||||
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
|
servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()}
|
||||||
return McpConfigResponse(mcp_servers=servers)
|
return McpConfigResponse(mcp_servers=servers)
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to update MCP configuration: {str(e)}")
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ class MemoryConfigResponse(BaseModel):
|
|||||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||||
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryStatusResponse(BaseModel):
|
class MemoryStatusResponse(BaseModel):
|
||||||
@@ -311,8 +310,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
"max_facts": 100,
|
"max_facts": 100,
|
||||||
"fact_confidence_threshold": 0.7,
|
"fact_confidence_threshold": 0.7,
|
||||||
"injection_enabled": true,
|
"injection_enabled": true,
|
||||||
"max_injection_tokens": 2000,
|
"max_injection_tokens": 2000
|
||||||
"token_counting": "tiktoken"
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -325,7 +323,6 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
token_counting=config.token_counting,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -354,7 +351,6 @@ async def get_memory_status() -> MemoryStatusResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
token_counting=config.token_counting,
|
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_
|
|||||||
from app.gateway.pagination import trim_run_message_page
|
from app.gateway.pagination import trim_run_message_page
|
||||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||||
from deerflow.runtime import serialize_channel_values_for_api
|
from deerflow.runtime import serialize_channel_values
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
router = APIRouter(prefix="/api/runs", tags=["runs"])
|
||||||
@@ -82,7 +82,7 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
|||||||
if checkpoint_tuple is not None:
|
if checkpoint_tuple is not None:
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
return serialize_channel_values_for_api(channel_values)
|
return serialize_channel_values(channel_values)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Request
|
from fastapi import APIRouter, Depends, Request
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
@@ -31,31 +30,6 @@ class SuggestionsResponse(BaseModel):
|
|||||||
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
|
||||||
|
|
||||||
|
|
||||||
# Matches a complete <think>...</think> block (case-insensitive, spans newlines).
|
|
||||||
_THINK_BLOCK_RE = re.compile(r"<think\b[^>]*>.*?</think\s*>", re.IGNORECASE | re.DOTALL)
|
|
||||||
# Matches a dangling, unclosed <think> (model truncated at max_tokens mid-thought).
|
|
||||||
_OPEN_THINK_RE = re.compile(r"<think\b[^>]*>", re.IGNORECASE)
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_think_blocks(text: str) -> str:
|
|
||||||
"""Remove reasoning-model ``<think>...</think>`` blocks from the response.
|
|
||||||
|
|
||||||
Reasoning models such as MiniMax-M3 inline their chain-of-thought into the
|
|
||||||
message ``content`` wrapped in ``<think>...</think>`` (``reasoning_split``
|
|
||||||
defaults to false), rather than exposing a separate ``reasoning_content``
|
|
||||||
field. The thinking text frequently contains ``[`` / ``]`` characters, which
|
|
||||||
corrupted the downstream ``find('[')`` / ``rfind(']')`` JSON extraction and
|
|
||||||
produced empty suggestions. We strip the reasoning before parsing so only
|
|
||||||
the actual answer remains.
|
|
||||||
"""
|
|
||||||
text = _THINK_BLOCK_RE.sub("", text)
|
|
||||||
# Drop any unclosed <think> (and everything after it) left by truncation.
|
|
||||||
open_match = _OPEN_THINK_RE.search(text)
|
|
||||||
if open_match:
|
|
||||||
text = text[: open_match.start()]
|
|
||||||
return text.strip()
|
|
||||||
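Review note: the sketch below illustrates the failure mode the docstring above describes. It reuses the two regular expressions from the diff. The sample model output and the `extract_list` helper are hypothetical and only approximate the `find('[')` / `rfind(']')` extraction.

```python
import json
import re

# Same patterns as in the diff above.
THINK_BLOCK_RE = re.compile(r"<think\b[^>]*>.*?</think\s*>", re.IGNORECASE | re.DOTALL)
OPEN_THINK_RE = re.compile(r"<think\b[^>]*>", re.IGNORECASE)


def strip_think_blocks(text: str) -> str:
    """Drop closed <think> blocks, then anything after a dangling <think>."""
    text = THINK_BLOCK_RE.sub("", text)
    open_match = OPEN_THINK_RE.search(text)
    if open_match:
        text = text[: open_match.start()]
    return text.strip()


def extract_list(text: str):
    """Hypothetical stand-in for the bracket-based JSON extraction."""
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end == -1 or end <= start:
        return None
    try:
        return json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return None


# Made-up model output: the reasoning text contains its own brackets.
raw = '<think>Options [a] or [b]?</think>["What next?", "Why?"]'

naive = extract_list(raw)                        # starts at "[a]", so parsing fails
cleaned = extract_list(strip_think_blocks(raw))  # only the answer remains

# A response truncated mid-thought keeps the text before the open tag.
truncated = strip_think_blocks('["ok"] <think>still reasoning [')
```

Without the stripping step, the first `[` belongs to the reasoning text, so the extracted slice is not valid JSON and no suggestions are returned.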
|
|
||||||
|
|
||||||
def _strip_markdown_code_fence(text: str) -> str:
|
def _strip_markdown_code_fence(text: str) -> str:
|
||||||
stripped = text.strip()
|
stripped = text.strip()
|
||||||
if not stripped.startswith("```"):
|
if not stripped.startswith("```"):
|
||||||
@@ -67,8 +41,7 @@ def _strip_markdown_code_fence(text: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_json_string_list(text: str) -> list[str] | None:
|
def _parse_json_string_list(text: str) -> list[str] | None:
|
||||||
candidate = _strip_think_blocks(text)
|
candidate = _strip_markdown_code_fence(text)
|
||||||
candidate = _strip_markdown_code_fence(candidate)
|
|
||||||
start = candidate.find("[")
|
start = candidate.find("[")
|
||||||
end = candidate.rfind("]")
|
end = candidate.rfind("]")
|
||||||
if start == -1 or end == -1 or end <= start:
|
if start == -1 or end == -1 or end <= start:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from app.gateway.authz import require_permission
|
|||||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||||
from app.gateway.pagination import trim_run_message_page
|
from app.gateway.pagination import trim_run_message_page
|
||||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values_for_api
|
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
||||||
@@ -192,7 +192,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
|
|||||||
if checkpoint_tuple is not None:
|
if checkpoint_tuple is not None:
|
||||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||||
channel_values = checkpoint.get("channel_values", {})
|
channel_values = checkpoint.get("channel_values", {})
|
||||||
return serialize_channel_values_for_api(channel_values)
|
return serialize_channel_values(channel_values)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||||
|
|
||||||
|
|||||||
@@ -17,15 +17,14 @@ import uuid
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
from langgraph.checkpoint.base import empty_checkpoint, uuid6
|
from langgraph.checkpoint.base import empty_checkpoint
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer
|
from app.gateway.deps import get_checkpointer
|
||||||
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values_for_api
|
from deerflow.runtime import serialize_channel_values
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.utils.time import coerce_iso, now_iso
|
from deerflow.utils.time import coerce_iso, now_iso
|
||||||
|
|
||||||
@@ -258,19 +257,11 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
thread_store = get_thread_store(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = now_iso()
|
now = now_iso()
|
||||||
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
|
||||||
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
existing_record = await thread_store.get(thread_id)
|
||||||
if existing_record is None and thread_owner_user_id:
|
|
||||||
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
|
||||||
if unscoped_record is not None:
|
|
||||||
if unscoped_record.get("user_id") != thread_owner_user_id:
|
|
||||||
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
|
||||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -285,7 +276,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
await thread_store.create(
|
await thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
**thread_owner_kwargs,
|
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -437,7 +427,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
created_at=coerce_iso(record.get("created_at", "")),
|
created_at=coerce_iso(record.get("created_at", "")),
|
||||||
updated_at=coerce_iso(record.get("updated_at", "")),
|
updated_at=coerce_iso(record.get("updated_at", "")),
|
||||||
metadata=record.get("metadata", {}),
|
metadata=record.get("metadata", {}),
|
||||||
values=serialize_channel_values_for_api(channel_values),
|
values=serialize_channel_values(channel_values),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -480,7 +470,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
|
|||||||
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
|
||||||
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
|
||||||
|
|
||||||
values = serialize_channel_values_for_api(channel_values)
|
values = serialize_channel_values(channel_values)
|
||||||
|
|
||||||
return ThreadStateResponse(
|
return ThreadStateResponse(
|
||||||
values=values,
|
values=values,
|
||||||
@@ -546,21 +536,9 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
metadata["step"] = metadata.get("step", 0) + 1
|
metadata["step"] = metadata.get("step", 0) + 1
|
||||||
metadata["writes"] = {body.as_node: body.values}
|
metadata["writes"] = {body.as_node: body.values}
|
||||||
|
|
||||||
# Assign a new checkpoint ID so aput performs an INSERT rather than an
|
|
||||||
# in-place REPLACE of the existing row. Use uuid6 (time-ordered) rather
|
|
||||||
# than uuid4 (random) so the new ID is always lexicographically greater
|
|
||||||
# than the previous one — LangGraph's checkpointers determine the "latest"
|
|
||||||
# checkpoint by max(checkpoint_ids) string order, matching the uuid6 epoch.
|
|
||||||
checkpoint["id"] = str(uuid6())
|
|
||||||
|
|
||||||
# aput requires checkpoint_ns in the config — use the same config used for the
|
# aput requires checkpoint_ns in the config — use the same config used for the
|
||||||
# read (which always includes checkpoint_ns=""). The fresh checkpoint ID is
|
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
|
||||||
# assigned above via checkpoint["id"]; keep checkpoint_id out of the config so
|
# so that aput generates a fresh checkpoint ID for the new snapshot.
|
||||||
# the write is keyed by the new checkpoint payload rather than the prior read.
|
|
||||||
# All supported savers (InMemorySaver, AsyncSqliteSaver, AsyncPostgresSaver)
|
|
||||||
# persist and echo back checkpoint["id"] verbatim — none mint their own — so
|
|
||||||
# the new_config below carries the uuid6 we assigned here. (Regression-locked
|
|
||||||
# by test_update_thread_state_inserts_new_checkpoint_each_call.)
|
|
||||||
write_config: dict[str, Any] = {
|
write_config: dict[str, Any] = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
@@ -579,7 +557,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
|
|
||||||
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
|
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
|
||||||
# reflects them immediately in both sqlite and memory backends.
|
# reflects them immediately in both sqlite and memory backends.
|
||||||
if thread_store and body.values and "title" in body.values:
|
if body.values and "title" in body.values:
|
||||||
new_title = body.values["title"]
|
new_title = body.values["title"]
|
||||||
if new_title: # Skip empty strings and None
|
if new_title: # Skip empty strings and None
|
||||||
try:
|
try:
|
||||||
@@ -588,7 +566,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
return ThreadStateResponse(
|
return ThreadStateResponse(
|
||||||
values=serialize_channel_values_for_api(channel_values),
|
values=serialize_channel_values(channel_values),
|
||||||
next=[],
|
next=[],
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
checkpoint_id=new_checkpoint_id,
|
checkpoint_id=new_checkpoint_id,
|
||||||
@@ -640,7 +618,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
if is_latest_checkpoint:
|
if is_latest_checkpoint:
|
||||||
messages = channel_values.get("messages")
|
messages = channel_values.get("messages")
|
||||||
if messages:
|
if messages:
|
||||||
values["messages"] = serialize_channel_values_for_api({"messages": messages}).get("messages", [])
|
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
|
||||||
is_latest_checkpoint = False
|
is_latest_checkpoint = False
|
||||||
|
|
||||||
# Derive next tasks
|
# Derive next tasks
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from types import SimpleNamespace
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -20,7 +19,7 @@ from langchain_core.messages import BaseMessage
|
|||||||
from langchain_core.messages.utils import convert_to_messages
|
from langchain_core.messages.utils import convert_to_messages
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
@@ -36,7 +35,6 @@ from deerflow.runtime import (
|
|||||||
run_agent,
|
run_agent,
|
||||||
)
|
)
|
||||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -317,32 +315,6 @@ async def start_run(
|
|||||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||||
)
|
)
|
||||||
|
|
||||||
owner_user_id = get_trusted_internal_owner_user_id(request)
|
|
||||||
# Stateless run endpoints carry thread_id in the request *body*, so the
|
|
||||||
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
|
||||||
# from the path param -- cannot protect them. Enforce thread ownership here,
|
|
||||||
# before any run is created, so one user cannot start runs on (or read /wait
|
|
||||||
# checkpoint state from) another user's thread. Missing rows (auto-created
|
|
||||||
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
|
||||||
# via check_access; only a thread already owned by another user is rejected
|
|
||||||
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
|
||||||
# channel runs act on behalf of the connection owner carried in
|
|
||||||
# X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
|
|
||||||
# bypassing the check -- a leaked internal token must not grant cross-user
|
|
||||||
# thread access.
|
|
||||||
user = getattr(request.state, "user", None)
|
|
||||||
if user is not None:
|
|
||||||
allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
|
|
||||||
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
|
||||||
# Channel workers may also act for the connection owner named in
|
|
||||||
# the trusted header (e.g. claiming a legacy default-owned channel
|
|
||||||
# thread for its real owner).
|
|
||||||
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
|
|
||||||
if not allowed:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
|
||||||
|
|
||||||
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
|
||||||
try:
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -352,7 +324,6 @@ async def start_run(
|
|||||||
kwargs={"input": body.input, "config": body.config},
|
kwargs={"input": body.input, "config": body.config},
|
||||||
multitask_strategy=body.multitask_strategy,
|
multitask_strategy=body.multitask_strategy,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
user_id=owner_user_id,
|
|
||||||
)
|
)
|
||||||
except ConflictError as exc:
|
except ConflictError as exc:
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
@@ -364,12 +335,6 @@ async def start_run(
|
|||||||
# (e.g. stateless runs).
|
# (e.g. stateless runs).
|
||||||
try:
|
try:
|
||||||
existing = await run_ctx.thread_store.get(thread_id)
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
if existing is None and owner_user_id:
|
|
||||||
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
|
||||||
if unscoped_existing is not None:
|
|
||||||
if unscoped_existing.get("user_id") != owner_user_id:
|
|
||||||
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
|
||||||
existing = await run_ctx.thread_store.get(thread_id)
|
|
||||||
if existing is None:
|
if existing is None:
|
||||||
await run_ctx.thread_store.create(
|
await run_ctx.thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -416,9 +381,6 @@ async def start_run(
|
|||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
finally:
|
|
||||||
if owner_context_token is not None:
|
|
||||||
reset_current_user(owner_context_token)
|
|
||||||
|
|
||||||
|
|
||||||
async def sse_consumer(
|
async def sse_consumer(
|
||||||
|
|||||||
+4
-22
@@ -228,13 +228,10 @@ Get current MCP server configurations.
|
|||||||
GET /api/mcp/config
|
GET /api/mcp/config
|
||||||
```
|
```
|
||||||
|
|
||||||
Requires an authenticated admin session. Sensitive env/header/OAuth secret
|
|
||||||
values are masked in the response.
|
|
||||||
|
|
||||||
**Response:**
|
**Response:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcp_servers": {
|
"mcpServers": {
|
||||||
"github": {
|
"github": {
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"type": "stdio",
|
"type": "stdio",
|
||||||
@@ -258,15 +255,10 @@ PUT /api/mcp/config
|
|||||||
Content-Type: application/json
|
Content-Type: application/json
|
||||||
```
|
```
|
||||||
|
|
||||||
Requires an authenticated admin session. API-managed `stdio` MCP servers may
|
|
||||||
only use allowed executable names for `command` (default: `npx`, `uvx`). Set
|
|
||||||
`DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST` to a comma-separated list when a
|
|
||||||
deployment needs additional trusted launchers.
|
|
||||||
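As a rough illustration of the allowlist rule described above, the following sketch resolves the permitted commands from the environment and checks a candidate `command`. The function names are hypothetical, and two behaviours are assumptions rather than confirmed details of the gateway: that the environment variable extends the defaults instead of replacing them, and that only bare executable names (no paths) are accepted.

```python
import os

# Defaults and variable name as stated in the documentation above.
DEFAULT_COMMANDS = {"npx", "uvx"}
ALLOWLIST_ENV = "DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST"


def allowed_stdio_commands(environ=None) -> set[str]:
    """Return defaults plus any comma-separated extras (assumed to extend)."""
    environ = os.environ if environ is None else environ
    raw = environ.get(ALLOWLIST_ENV, "")
    extra = {item.strip() for item in raw.split(",") if item.strip()}
    return DEFAULT_COMMANDS | extra


def is_allowed_stdio_command(command: str, environ=None) -> bool:
    """Assumption: exact match on the bare name, so '/usr/bin/npx' is rejected."""
    return command.strip() in allowed_stdio_commands(environ)
```

For example, with `DEER_FLOW_MCP_STDIO_COMMAND_ALLOWLIST="docker, node"` this sketch would accept `docker` in addition to `npx` and `uvx`, while still rejecting `bash`.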
|
|
||||||
**Request Body:**
|
**Request Body:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcp_servers": {
|
"mcpServers": {
|
||||||
"github": {
|
"github": {
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"type": "stdio",
|
"type": "stdio",
|
||||||
@@ -284,18 +276,8 @@ deployment needs additional trusted launchers.
|
|||||||
**Response:**
|
**Response:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"mcp_servers": {
|
"success": true,
|
||||||
"github": {
|
"message": "MCP configuration updated"
|
||||||
"enabled": true,
|
|
||||||
"type": "stdio",
|
|
||||||
"command": "npx",
|
|
||||||
"args": ["-y", "@modelcontextprotocol/server-github"],
|
|
||||||
"env": {
|
|
||||||
"GITHUB_TOKEN": "***"
|
|
||||||
},
|
|
||||||
"description": "GitHub operations"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -67,11 +67,6 @@ The normal workflow is:
|
|||||||
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
|
3. Add or update a focused runtime anchor in `backend/tests/blocking_io/`.
|
||||||
4. Let CI prevent that path from regressing.
|
4. Let CI prevent that path from regressing.
|
||||||
|
|
||||||
Contributors changing backend async code can run the `blocking-io-guard` skill
|
|
||||||
(`.agent/skills/blocking-io-guard/`) to execute steps 1–3 for their own diff: it
|
|
||||||
scans the change for blocking-IO candidates, drafts or extends a runtime anchor,
|
|
||||||
and verifies the anchor fails when the blocking IO regresses.
|
|
||||||
|
|
||||||
Runtime detection has two maintenance paths.
|
Runtime detection has two maintenance paths.
|
||||||
|
|
||||||
### Add a runtime rule
|
### Add a runtime rule
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ models:
|
|||||||
base_url: https://api.minimax.io/v1
|
base_url: https://api.minimax.io/v1
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
||||||
supports_vision: false # M2.7 is text-only; M3 supports vision
|
supports_vision: true
|
||||||
|
|
||||||
- name: minimax-m2.7-highspeed
|
- name: minimax-m2.7-highspeed
|
||||||
display_name: MiniMax M2.7 Highspeed
|
display_name: MiniMax M2.7 Highspeed
|
||||||
@@ -123,7 +123,7 @@ models:
|
|||||||
base_url: https://api.minimax.io/v1
|
base_url: https://api.minimax.io/v1
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
temperature: 1.0 # MiniMax requires temperature in (0.0, 1.0]
|
||||||
supports_vision: false # M2.7 is text-only; M3 supports vision
|
supports_vision: true
|
||||||
- name: openrouter-gemini-2.5-flash
|
- name: openrouter-gemini-2.5-flash
|
||||||
display_name: Gemini 2.5 Flash (OpenRouter)
|
display_name: Gemini 2.5 Flash (OpenRouter)
|
||||||
use: langchain_openai:ChatOpenAI
|
use: langchain_openai:ChatOpenAI
|
||||||
|
|||||||
@@ -1,122 +0,0 @@
|
|||||||
# IM Channel Connections
|
|
||||||
|
|
||||||
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
|
|
||||||
|
|
||||||
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
Configure the actual IM bots under the existing `channels` block:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
channels:
|
|
||||||
telegram:
|
|
||||||
enabled: true
|
|
||||||
bot_token: $TELEGRAM_BOT_TOKEN
|
|
||||||
|
|
||||||
slack:
|
|
||||||
enabled: true
|
|
||||||
bot_token: $SLACK_BOT_TOKEN
|
|
||||||
app_token: $SLACK_APP_TOKEN
|
|
||||||
|
|
||||||
discord:
|
|
||||||
enabled: true
|
|
||||||
bot_token: $DISCORD_BOT_TOKEN
|
|
||||||
|
|
||||||
feishu:
|
|
||||||
enabled: true
|
|
||||||
app_id: $FEISHU_APP_ID
|
|
||||||
app_secret: $FEISHU_APP_SECRET
|
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
client_id: $DINGTALK_CLIENT_ID
|
|
||||||
client_secret: $DINGTALK_CLIENT_SECRET
|
|
||||||
|
|
||||||
wechat:
|
|
||||||
enabled: true
|
|
||||||
bot_token: $WECHAT_BOT_TOKEN
|
|
||||||
|
|
||||||
wecom:
|
|
||||||
enabled: true
|
|
||||||
bot_id: $WECOM_BOT_ID
|
|
||||||
bot_secret: $WECOM_BOT_SECRET
|
|
||||||
```
|
|
||||||
|
|
||||||
Then enable user bindings in `channel_connections`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
channel_connections:
|
|
||||||
enabled: true
|
|
||||||
|
|
||||||
telegram:
|
|
||||||
enabled: true
|
|
||||||
bot_username: $TELEGRAM_BOT_USERNAME
|
|
||||||
|
|
||||||
slack:
|
|
||||||
enabled: true
|
|
||||||
|
|
||||||
discord:
|
|
||||||
enabled: true
|
|
||||||
|
|
||||||
feishu:
|
|
||||||
enabled: true
|
|
||||||
|
|
||||||
dingtalk:
|
|
||||||
enabled: true
|
|
||||||
|
|
||||||
wechat:
|
|
||||||
enabled: true
|
|
||||||
|
|
||||||
wecom:
|
|
||||||
enabled: true
|
|
||||||
```
|
|
||||||
|
|
||||||
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
|
|
||||||
|
|
||||||
## Connect Flow
|
|
||||||
|
|
||||||
Telegram:
|
|
||||||
|
|
||||||
- The frontend creates a short one-time code.
|
|
||||||
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
|
|
||||||
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
|
|
||||||
|
|
||||||
Slack:
|
|
||||||
|
|
||||||
- The frontend creates a short one-time code.
|
|
||||||
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
|
|
||||||
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
|
|
||||||
|
|
||||||
Discord:
|
|
||||||
|
|
||||||
- The frontend creates a short one-time code.
|
|
||||||
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
|
|
||||||
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
|
|
||||||
|
|
||||||
Feishu/Lark, DingTalk, WeChat, and WeCom:
|
|
||||||
|
|
||||||
- The frontend creates a short one-time code.
|
|
||||||
- The UI shows `Send /connect <code> to the DeerFlow <Provider> bot.`
|
|
||||||
- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user.
|
|
||||||
|
|
||||||
Codes use 128 bits of randomness, expire after 10 minutes, and are single-use.
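
The three properties above (128 bits of randomness, a 10-minute lifetime, single use) can be illustrated with a minimal sketch. `ConnectCodeStore`, `issue`, and `redeem` are hypothetical names, and the in-memory dict stands in for the SQL-backed state the document describes; this is not the DeerFlow implementation.

```python
from __future__ import annotations

import secrets
import time

CODE_TTL_SECONDS = 10 * 60  # ten-minute lifetime, as stated above


class ConnectCodeStore:
    """Hypothetical in-memory store; the documented flow persists codes in SQL."""

    def __init__(self, clock=time.monotonic):
        self._clock = clock
        self._codes = {}  # code -> (owner_user_id, expires_at)

    def issue(self, owner_user_id: str) -> str:
        code = secrets.token_urlsafe(16)  # 16 random bytes = 128 bits
        self._codes[code] = (owner_user_id, self._clock() + CODE_TTL_SECONDS)
        return code

    def redeem(self, code: str) -> str | None:
        # pop() removes the code on first use, so a replay always fails,
        # even when the first attempt arrived after expiry.
        entry = self._codes.pop(code, None)
        if entry is None:
            return None
        owner_user_id, expires_at = entry
        if self._clock() >= expires_at:
            return None
        return owner_user_id
```

Injecting the clock is only a convenience for testing expiry without sleeping; a persistent store would compare against a stored expiry timestamp instead.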
|
|
||||||
|
|
||||||
## Runtime Model
|
|
||||||
|
|
||||||
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
|
|
||||||
|
|
||||||
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
|
|
||||||
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
|
|
||||||
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
|
|
||||||
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
|
|
||||||
|
|
||||||
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
|
|
||||||
|
|
||||||
## Security Notes
|
|
||||||
|
|
||||||
- Browser APIs remain authenticated and CSRF-protected.
|
|
||||||
- Connect codes are 128-bit random, short-lived, and single-use.
|
|
||||||
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
|
|
||||||
- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets.
|
|
||||||
- This implementation does not add public provider callback or webhook routes.
|
|
||||||
@@ -31,8 +31,7 @@ Current injection format:
|
|||||||
|
|
||||||
Token counting:
|
Token counting:
|
||||||
- Uses `tiktoken` (`cl100k_base`) when available
|
- Uses `tiktoken` (`cl100k_base`) when available
|
||||||
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
|
- Falls back to `len(text) // 4` if tokenizer import fails
|
||||||
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
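
A rough sketch of such a character-based estimate, using the stated ratios of about 2 characters per token for CJK text and about 4 for everything else. The function names, the Unicode ranges treated as CJK, and the round-up behaviour are assumptions made for illustration; the project's actual fallback may differ in all three.

```python
import math


def _is_cjk(ch: str) -> bool:
    """Assumed definition of 'CJK' for this sketch; not taken from the source."""
    cp = ord(ch)
    return (
        0x4E00 <= cp <= 0x9FFF      # CJK Unified Ideographs
        or 0x3040 <= cp <= 0x30FF   # Hiragana and Katakana
        or 0xAC00 <= cp <= 0xD7AF   # Hangul syllables
    )


def estimate_tokens(text: str) -> int:
    """Network-free estimate: ~2 chars/token for CJK, ~4 chars/token otherwise."""
    cjk = sum(1 for ch in text if _is_cjk(ch))
    other = len(text) - cjk
    return math.ceil(cjk / 2 + other / 4)
```

Compared with a flat `len(text) // 4`, this counts a four-character Chinese string as 2 tokens instead of 1, which is the undercount the CJK-aware variant is meant to reduce.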
|
|
||||||
|
|
||||||
## Known Gap
|
## Known Gap
|
||||||
|
|
||||||
|
|||||||
@@ -1,120 +0,0 @@
|
|||||||
# Record/Replay E2E — front-back contract verification
|
|
||||||
|
|
||||||
Deterministic, **key-free** end-to-end checks that a backend change can't
|
|
||||||
silently break the frontend (and vice-versa). Two complementary layers, fed by a
|
|
||||||
single recording.
|
|
||||||
|
|
||||||
## Why
|
|
||||||
|
|
||||||
The mock-based frontend e2e hand-writes the backend's JSON/SSE, so a backend
|
|
||||||
schema or SSE change passes green ("fake green"). These layers replay a recorded
|
|
||||||
**real** run against the **real** backend (and, for Layer 2, the real frontend),
|
|
||||||
so contract drift turns the build red instead.
|
|
||||||
|
|
||||||
## The two layers
|
|
||||||
|
|
||||||
- **Layer 1 — backend golden** (`tests/test_replay_golden.py`): replays a fixture
|
|
||||||
through the real FastAPI gateway with `ReplayChatModel` and asserts the streamed
|
|
||||||
SSE event sequence equals a committed golden. Fast, no browser. Guards protocol
|
|
||||||
*shape*.
|
|
||||||
- **Layer 2 — full-stack render** (`frontend/tests/e2e-real-backend/`): real
|
|
||||||
Next.js + real gateway (replay model) + Chromium; asserts the replayed
|
|
||||||
auto-title and a follow-up suggestion render in the browser. Guards semantic
|
|
||||||
*render*. (Complementary to Layer 1 — neither subsumes the other.)
|
|
||||||
|
|
||||||
Layer 2 also hosts **cross-stack contract scenarios** — the dangerous class
|
|
||||||
where a backend change silently breaks a frontend assumption and *both sides'
|
|
||||||
unit tests stay green*. See below.
|
|
||||||
|
|
||||||
## Cross-stack scenario: multi-run render order (`multi-run-order.spec.ts`)
|
|
||||||
|
|
||||||
Regression guard for issue **#3352** (after context compression, refreshing a
|
|
||||||
thread rendered history out of order). Root cause was a front-back desync:
|
|
||||||
backend `RunManager.list_by_thread` returns runs **newest-first** (PR #2932),
|
|
||||||
while the frontend (`core/threads/hooks.ts`) iterated runs and **prepended** each
|
|
||||||
loaded page — inverting chronological order once the checkpoint no longer held
|
|
||||||
the older messages. The backend ordering test was green throughout, and the
|
|
||||||
frontend regression unit test hardcodes "backend returns newest-first" in a mock,
|
|
||||||
so only a *real frontend against a real backend* catches the desync.
|
|
||||||
|
|
||||||
This scenario does **not** record a conversation. It uses a **test-only seeder**
|
|
||||||
(`tests/seed_runs_router.py`, mounted on the replay gateway only when
|
|
||||||
`DEERFLOW_ENABLE_TEST_SEED=1`) to stand up a thread with ≥2 runs and per-run
|
|
||||||
message events — and deliberately **no checkpoint**, which is the #3352
|
|
||||||
precondition: it forces the frontend's per-run reload path to be the sole source
|
|
||||||
of truth so the ordering bug becomes observable. The seeder writes through the
|
|
||||||
gateway's own run/event stores using the request's auth context, so the real
|
|
||||||
`list_by_thread` → `/runs/{id}/messages` → prepend path runs live. Reverting the
|
|
||||||
#3354 frontend fix turns this spec red.
|
|
||||||
|
|
||||||
## How replay works
|
|
||||||
|
|
||||||
`tests/replay_provider.py::ReplayChatModel` returns recorded assistant turns keyed
|
|
||||||
by a **normalized hash of the model caller + conversation**. The conversation is
|
|
||||||
human / ai / tool messages — role, text, tool-call name+args; with
|
|
||||||
`<system-reminder>`, dates, UUIDs, tmp paths stripped. The caller is the stable
|
|
||||||
source of the model call (`lead_agent`, `middleware:title`, `suggest_agent`,
|
|
||||||
`subagent:*`, etc.). A miss raises loudly rather than passing silently.
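
As an illustration of the keying idea described above, the sketch below hashes the caller together with a normalized view of the human / ai / tool messages. The message shape (plain dicts with `role`, `text`, `tool_calls`), the regular expressions, and the names `normalize` and `replay_key` are all assumptions for this example; the real `ReplayChatModel` may normalize differently.

```python
import hashlib
import json
import re

# Volatile fragments removed before hashing (patterns are illustrative).
_NOISE_PATTERNS = [
    re.compile(r"<system-reminder>.*?</system-reminder>", re.DOTALL),
    re.compile(
        r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}"
        r"-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
    ),  # UUIDs, stripped before dates so their digits are not half-matched
    re.compile(r"\d{4}-\d{2}-\d{2}"),  # ISO dates
    re.compile(r"/tmp/\S+"),           # temporary paths
]


def normalize(text: str) -> str:
    for pattern in _NOISE_PATTERNS:
        text = pattern.sub("", text)
    return " ".join(text.split())  # collapse leftover whitespace


def replay_key(caller: str, messages: list[dict]) -> str:
    """Hash caller + conversation; messages of any other role are ignored."""
    payload = {
        "caller": caller,
        "conversation": [
            {
                "role": m["role"],
                "text": normalize(m.get("text", "")),
                "tool_calls": [
                    {"name": c["name"], "args": c["args"]}
                    for c in m.get("tool_calls", [])
                ],
            }
            for m in messages
            if m["role"] in {"human", "ai", "tool"}
        ],
    }
    blob = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()
```

With this shape, two recordings that differ only in a UUID, a date, or a temp path map to the same key, while the same conversation issued by a different caller maps to a different one.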
|
|
||||||
|
|
||||||
**The system prompt is excluded from the match key.** The lead-agent system
|
|
||||||
prompt is a living, frequently-edited implementation detail — its wording changes
|
|
||||||
across PRs (e.g. #3195 added a "File Editing Workflow" section). Hashing it would
|
|
||||||
make every fixture go stale and red-fail unrelated PRs the moment anyone edits the
|
|
||||||
prompt. The conversation flow (user input → tool calls → results → answer) is the
|
|
||||||
stable contract that identifies a recorded turn. The caller still stays in the
|
|
||||||
key so two different model users with identical conversation text do not compete
|
|
||||||
for the same replay bucket. (This mirrors how open-design's mock picker keys on
|
|
||||||
the user prompt, not the system internals.) Combined with pinning skills +
|
|
||||||
extensions empty and disabling memory/summarization
|
|
||||||
(`tests/_replay_fixture.py::build_config_yaml`), a fixture replays the same across
|
|
||||||
machines, days, prompt edits, and CI. Replaying needs **no API key**.
|
|
||||||
|
|
||||||
A swallowed hash-miss keeps the SSE *event shapes* identical (the gateway wraps it
|
|
||||||
into a normal assistant error message), so the Layer-1 golden can't catch a miss
|
|
||||||
by shape alone — it inspects `replay_provider.replay_misses()` and fails loud
|
|
||||||
instead. Layer-2 already fails on a miss (the recorded turns never render).
|
|
||||||
|
|
||||||
## Record a new scenario (needs a real key — dev machine only)
|
|
||||||
|
|
||||||
Recording drives the **real frontend** so captured inputs match exactly what the
|
|
||||||
browser sends; fixtures contain no API key.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 1. drive the real frontend against a real-model gateway, capturing model calls
|
|
||||||
OPENAI_API_KEY=... OPENAI_API_BASE=<openai-compatible-endpoint>/v1 \
|
|
||||||
DEERFLOW_RECORD_OUT=/tmp/rec/turns.jsonl RECORD_MODEL=<model> \
|
|
||||||
bash -c 'cd frontend && pnpm exec playwright test -c playwright.record.config.ts'
|
|
||||||
|
|
||||||
# 2. stitch the capture into a fixture
|
|
||||||
cd backend && uv run python scripts/build_fixture_from_jsonl.py \
|
|
||||||
--jsonl /tmp/rec/turns.jsonl --meta /tmp/rec/turns.jsonl.meta.json \
|
|
||||||
--out tests/fixtures/replay/<scenario>.<mode>.json --model <model>
|
|
||||||
|
|
||||||
# 3. regenerate the committed golden
|
|
||||||
DEERFLOW_WRITE_GOLDEN=1 PYTHONPATH=. uv run pytest tests/test_replay_golden.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Run (no key)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd backend && PYTHONPATH=. uv run pytest tests/test_replay_golden.py # Layer 1
|
|
||||||
cd frontend && pnpm exec playwright test -c playwright.real-backend.config.ts # Layer 2
|
|
||||||
```
|
|
||||||
|
|
||||||
## CI
|
|
||||||
|
|
||||||
`.github/workflows/replay-e2e.yml` runs both layers on changes to **either** side
|
|
||||||
of the contract (`frontend/**`, `backend/app/gateway/**`,
|
|
||||||
`backend/packages/harness/**`, fixtures). DOM assertions are the gate; the rendered
|
|
||||||
screenshot + Playwright HTML report are uploaded as a CI artifact.
|
|
||||||
|
|
||||||
## Known limitations
|
|
||||||
|
|
||||||
- Visual regression baselines are OS-specific, so they are a **local dev gate
|
|
||||||
only** (gitignored); CI uploads the render as an artifact for human review
|
|
||||||
instead of hard-asserting a cross-OS baseline.
|
|
||||||
- Fixtures are coupled to the recording-time prompt; if new
|
|
||||||
environment-dependent content enters the system prompt, extend the
|
|
||||||
normalization in `replay_provider.py` (or pin it in `build_config_yaml`).
|
|
||||||
- Re-record a scenario if the agent graph changes how many model calls it makes
|
|
||||||
— the replay raises loudly on a hash miss pointing at the divergence.
|
|
||||||
@@ -127,8 +127,8 @@ complex_agent = create_agent_for_task("high")
|
|||||||
## How It Works
|
## How It Works
|
||||||
|
|
||||||
1. When `make_lead_agent(config)` is called, it extracts `is_plan_mode` from `config.configurable`
|
1. When `make_lead_agent(config)` is called, it extracts `is_plan_mode` from `config.configurable`
|
||||||
2. The config is passed to `build_middlewares(config)`
|
2. The config is passed to `_build_middlewares(config)`
|
||||||
3. `build_middlewares()` reads `is_plan_mode` and calls `_create_todo_list_middleware(is_plan_mode)`
|
3. `_build_middlewares()` reads `is_plan_mode` and calls `_create_todo_list_middleware(is_plan_mode)`
|
||||||
4. If `is_plan_mode=True`, a `TodoListMiddleware` instance is created and added to the middleware chain
|
4. If `is_plan_mode=True`, a `TodoListMiddleware` instance is created and added to the middleware chain
|
||||||
5. The middleware automatically adds a `write_todos` tool to the agent's toolset
|
5. The middleware automatically adds a `write_todos` tool to the agent's toolset
|
||||||
6. The agent can use this tool to manage tasks during execution
|
6. The agent can use this tool to manage tasks during execution
|
||||||
@@ -141,7 +141,7 @@ make_lead_agent(config)
|
|||||||
│
|
│
|
||||||
├─> Extracts: is_plan_mode = config.configurable.get("is_plan_mode", False)
|
├─> Extracts: is_plan_mode = config.configurable.get("is_plan_mode", False)
|
||||||
│
|
│
|
||||||
└─> build_middlewares(config)
|
└─> _build_middlewares(config)
|
||||||
│
|
│
|
||||||
├─> ThreadDataMiddleware
|
├─> ThreadDataMiddleware
|
||||||
├─> SandboxMiddleware
|
├─> SandboxMiddleware
|
||||||
@@ -156,7 +156,7 @@ make_lead_agent(config)
|
|||||||
### Agent Module
|
### Agent Module
|
||||||
- **Location**: `packages/harness/deerflow/agents/lead_agent/agent.py`
|
- **Location**: `packages/harness/deerflow/agents/lead_agent/agent.py`
|
||||||
- **Function**: `_create_todo_list_middleware(is_plan_mode: bool)` - Creates TodoListMiddleware if plan mode is enabled
|
- **Function**: `_create_todo_list_middleware(is_plan_mode: bool)` - Creates TodoListMiddleware if plan mode is enabled
|
||||||
- **Function**: `build_middlewares(config: RunnableConfig)` - Builds middleware chain based on runtime config
|
- **Function**: `_build_middlewares(config: RunnableConfig)` - Builds middleware chain based on runtime config
|
||||||
- **Function**: `make_lead_agent(config: RunnableConfig)` - Creates agent with appropriate middlewares
|
- **Function**: `make_lead_agent(config: RunnableConfig)` - Creates agent with appropriate middlewares
|
||||||
|
|
||||||
### Runtime Configuration
|
### Runtime Configuration
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
@@ -47,9 +48,12 @@ from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
|||||||
from deerflow.skills.types import Skill
|
from deerflow.skills.types import Skill
|
||||||
from deerflow.tracing import build_tracing_callbacks
|
from deerflow.tracing import build_tracing_callbacks
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
if TYPE_CHECKING:
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
_BOOTSTRAP_SKILL_NAMES = {"bootstrap"}
|
from deerflow.tools.builtins.tool_search import DeferredToolSetup
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_runtime_config(config: RunnableConfig) -> dict:
|
def _get_runtime_config(config: RunnableConfig) -> dict:
|
||||||
@@ -267,31 +271,21 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
|||||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||||
def build_middlewares(
|
def _build_middlewares(
|
||||||
config: RunnableConfig,
|
config: RunnableConfig,
|
||||||
model_name: str | None,
|
model_name: str | None,
|
||||||
agent_name: str | None = None,
|
agent_name: str | None = None,
|
||||||
custom_middlewares: list[AgentMiddleware] | None = None,
|
custom_middlewares: list[AgentMiddleware] | None = None,
|
||||||
*,
|
*,
|
||||||
available_skills: set[str] | None = None,
|
|
||||||
app_config: AppConfig | None = None,
|
app_config: AppConfig | None = None,
|
||||||
deferred_setup=None,
|
deferred_setup=None,
|
||||||
):
|
):
|
||||||
"""Build the lead-agent middleware chain based on runtime configuration.
|
"""Build middleware chain based on runtime configuration.
|
||||||
|
|
||||||
Public entry point for the lead agent's full middleware composition. Used by
|
|
||||||
``make_lead_agent`` and by the embedded ``DeerFlowClient`` (a lead-agent variant
|
|
||||||
that needs the identical chain). Keep this name stable: it is imported across a
|
|
||||||
module boundary, so renames/signature changes ripple into ``client.py``.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||||
model_name: Resolved runtime model name; gates vision-only middleware.
|
|
||||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||||
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
||||||
app_config: Explicit AppConfig; falls back to ``get_app_config()`` when omitted.
|
|
||||||
deferred_setup: Optional deferred-MCP-tool setup that attaches
|
|
||||||
``DeferredToolFilterMiddleware`` when ``tool_search`` is enabled.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of middleware instances.
|
List of middleware instances.
|
||||||
@@ -305,13 +299,6 @@ def build_middlewares(
|
|||||||
|
|
||||||
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
middlewares.append(DynamicContextMiddleware(agent_name=agent_name, app_config=resolved_app_config))
|
||||||
|
|
||||||
# Deterministically load a full SKILL.md when the user starts the turn with
|
|
||||||
# /skill-name. This keeps the base system prompt metadata-only while giving
|
|
||||||
# explicit user activation priority over model-side relevance guessing.
|
|
||||||
from deerflow.agents.middlewares.skill_activation_middleware import SkillActivationMiddleware
|
|
||||||
|
|
||||||
middlewares.append(SkillActivationMiddleware(available_skills=available_skills, app_config=resolved_app_config))
|
|
||||||
|
|
||||||
# Add summarization middleware if enabled
|
# Add summarization middleware if enabled
|
||||||
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
|
||||||
if summarization_middleware is not None:
|
if summarization_middleware is not None:
|
||||||
@@ -377,9 +364,29 @@ def build_middlewares(
|
|||||||
return middlewares
|
return middlewares
|
||||||
|
|
||||||
|
|
||||||
|
def _assemble_deferred(filtered_tools: list[BaseTool], *, enabled: bool) -> tuple[list[BaseTool], DeferredToolSetup]:
|
||||||
|
"""Build the final tool list + deferred setup from a policy-filtered list.
|
||||||
|
|
||||||
|
Call AFTER tool-policy filtering so the deferred catalog never exposes a
|
||||||
|
tool the agent is not allowed to use. Fail-closed: if tool_search is enabled
|
||||||
|
and MCP tools survived filtering but no deferred set was recovered, raise
|
||||||
|
rather than silently binding their full schemas to the model.
|
||||||
|
"""
|
||||||
|
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
|
||||||
|
from deerflow.tools.mcp_metadata import is_mcp_tool
|
||||||
|
|
||||||
|
deferred_setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
|
||||||
|
if enabled and not deferred_setup.deferred_names and any(is_mcp_tool(t) for t in filtered_tools):
|
||||||
|
raise RuntimeError("tool_search enabled and MCP tools survived policy filtering, but no deferred set was recovered — refusing to bind MCP schemas (fail-closed).")
|
||||||
|
final_tools = list(filtered_tools)
|
||||||
|
if deferred_setup.tool_search_tool:
|
||||||
|
final_tools.append(deferred_setup.tool_search_tool)
|
||||||
|
return final_tools, deferred_setup
|
||||||
|
|
||||||
|
|
||||||
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
|
||||||
if is_bootstrap:
|
if is_bootstrap:
|
||||||
return set(_BOOTSTRAP_SKILL_NAMES)
|
return {"bootstrap"}
|
||||||
if agent_config and agent_config.skills is not None:
|
if agent_config and agent_config.skills is not None:
|
||||||
return set(agent_config.skills)
|
return set(agent_config.skills)
|
||||||
return None
|
return None
|
||||||
@@ -410,7 +417,6 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
# Lazy import to avoid circular dependency
|
# Lazy import to avoid circular dependency
|
||||||
from deerflow.tools import get_available_tools
|
from deerflow.tools import get_available_tools
|
||||||
from deerflow.tools.builtins import setup_agent, update_agent
|
from deerflow.tools.builtins import setup_agent, update_agent
|
||||||
from deerflow.tools.builtins.tool_search import assemble_deferred_tools
|
|
||||||
|
|
||||||
cfg = _get_runtime_config(config)
|
cfg = _get_runtime_config(config)
|
||||||
resolved_app_config = app_config
|
resolved_app_config = app_config
|
||||||
@@ -485,25 +491,17 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
|
|
||||||
if is_bootstrap:
|
if is_bootstrap:
|
||||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||||
# Keep the bootstrap skill set intentionally narrow so agent creation
|
|
||||||
# remains deterministic before the custom agent's own config exists.
|
|
||||||
raw_tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
raw_tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent]
|
||||||
filtered = filter_tools_by_skill_allowed_tools(raw_tools, skills_for_tool_policy)
|
filtered = filter_tools_by_skill_allowed_tools(raw_tools, skills_for_tool_policy)
|
||||||
final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled)
|
final_tools, setup = _assemble_deferred(filtered, enabled=resolved_app_config.tool_search.enabled)
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False),
|
||||||
tools=final_tools,
|
tools=final_tools,
|
||||||
middleware=build_middlewares(
|
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config, deferred_setup=setup),
|
||||||
config,
|
|
||||||
model_name=model_name,
|
|
||||||
available_skills=set(_BOOTSTRAP_SKILL_NAMES),
|
|
||||||
app_config=resolved_app_config,
|
|
||||||
deferred_setup=setup,
|
|
||||||
),
|
|
||||||
system_prompt=apply_prompt_template(
|
system_prompt=apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
available_skills=set(_BOOTSTRAP_SKILL_NAMES),
|
available_skills=set(["bootstrap"]),
|
||||||
app_config=resolved_app_config,
|
app_config=resolved_app_config,
|
||||||
deferred_names=setup.deferred_names,
|
deferred_names=setup.deferred_names,
|
||||||
),
|
),
|
||||||
@@ -516,23 +514,16 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
|
|||||||
# Default lead agent (unchanged behavior)
|
# Default lead agent (unchanged behavior)
|
||||||
raw_tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
raw_tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config)
|
||||||
filtered = filter_tools_by_skill_allowed_tools(raw_tools + extra_tools, skills_for_tool_policy)
|
filtered = filter_tools_by_skill_allowed_tools(raw_tools + extra_tools, skills_for_tool_policy)
|
||||||
final_tools, setup = assemble_deferred_tools(filtered, enabled=resolved_app_config.tool_search.enabled)
|
final_tools, setup = _assemble_deferred(filtered, enabled=resolved_app_config.tool_search.enabled)
|
||||||
return create_agent(
|
return create_agent(
|
||||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False),
|
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False),
|
||||||
tools=final_tools,
|
tools=final_tools,
|
||||||
middleware=build_middlewares(
|
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config, deferred_setup=setup),
|
||||||
config,
|
|
||||||
model_name=model_name,
|
|
||||||
agent_name=agent_name,
|
|
||||||
available_skills=available_skills,
|
|
||||||
app_config=resolved_app_config,
|
|
||||||
deferred_setup=setup,
|
|
||||||
),
|
|
||||||
system_prompt=apply_prompt_template(
|
system_prompt=apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
available_skills=available_skills,
|
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
|
||||||
app_config=resolved_app_config,
|
app_config=resolved_app_config,
|
||||||
deferred_names=setup.deferred_names,
|
deferred_names=setup.deferred_names,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from deerflow.config.agents_config import load_agent_soul
|
|||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
from deerflow.skills.types import Skill, SkillCategory
|
from deerflow.skills.types import Skill, SkillCategory
|
||||||
from deerflow.subagents import get_available_subagent_names
|
from deerflow.subagents import get_available_subagent_names
|
||||||
from deerflow.tools.builtins.tool_search import get_deferred_tools_prompt_section
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
@@ -586,11 +585,7 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||||
memory_content = format_memory_for_injection(
|
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||||
memory_data,
|
|
||||||
max_tokens=config.max_injection_tokens,
|
|
||||||
use_tiktoken=(config.token_counting == "tiktoken"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memory_content.strip():
|
if not memory_content.strip():
|
||||||
return ""
|
return ""
|
||||||
@@ -629,11 +624,6 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
|||||||
4. Load referenced resources only when needed during execution
|
4. Load referenced resources only when needed during execution
|
||||||
5. Follow the skill's instructions precisely
|
5. Follow the skill's instructions precisely
|
||||||
|
|
||||||
**Explicit Slash Skill Activation:**
|
|
||||||
- If the user starts a request with `/<skill-name>`, that skill was explicitly requested for the current turn.
|
|
||||||
- Follow the activated skill before choosing a general workflow.
|
|
||||||
- The runtime injects the activated skill content for explicit slash activations; do not call `read_file` for that SKILL.md again unless the injected skill references supporting resources you need.
|
|
||||||
|
|
||||||
**Skills are located at:** {container_base_path}
|
**Skills are located at:** {container_base_path}
|
||||||
{skill_evolution_section}
|
{skill_evolution_section}
|
||||||
{skills_list}
|
{skills_list}
|
||||||
@@ -703,6 +693,19 @@ Rules:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_deferred_tools_prompt_section(*, deferred_names: frozenset[str] = frozenset()) -> str:
|
||||||
|
"""Generate <available-deferred-tools> from an explicit deferred-name set.
|
||||||
|
|
||||||
|
Lists only names so the agent knows what exists and can use tool_search to
|
||||||
|
load them. Returns empty string when there are no deferred tools. The set is
|
||||||
|
computed at agent build time (after tool-policy filtering) and passed in.
|
||||||
|
"""
|
||||||
|
if not deferred_names:
|
||||||
|
return ""
|
||||||
|
names = "\n".join(sorted(deferred_names))
|
||||||
|
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
||||||
|
|
||||||
|
|
||||||
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
|
||||||
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
||||||
if app_config is None:
|
if app_config is None:
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
"""Prompt templates for memory update and injection."""
|
"""Prompt templates for memory update and injection."""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import threading
|
from typing import Any
|
||||||
import time
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import tiktoken
|
import tiktoken
|
||||||
@@ -167,137 +160,26 @@ Rules:
|
|||||||
Return ONLY valid JSON."""
|
Return ONLY valid JSON."""
|
||||||
|
|
||||||
|
|
||||||
# Module-level tiktoken encoding cache. Populated lazily on first use;
|
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||||
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
|
||||||
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
|
||||||
# (potentially slow) first ``get_encoding`` call.
|
|
||||||
#
|
|
||||||
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
|
||||||
# a network-restricted environment does not re-attempt the blocking BPE
|
|
||||||
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
|
||||||
# failure is allowed to expire so a transient network outage can self-heal back
|
|
||||||
# to accurate tiktoken counting without a process restart. A load already in
|
|
||||||
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
|
||||||
# fall back immediately instead of spawning more blocking
|
|
||||||
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
|
||||||
# config to skip tiktoken entirely.
|
|
||||||
_TIKTOKEN_ENCODING_MISSING = object()
|
|
||||||
_TIKTOKEN_ENCODING_LOADING = object()
|
|
||||||
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
|
||||||
# tuning constant rather than a user-facing config: it only affects how quickly
|
|
||||||
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
|
||||||
# Deployments that want to avoid tiktoken's network dependency entirely should
|
|
||||||
# set ``memory.token_counting: char`` instead of tuning this value.
|
|
||||||
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
|
||||||
_tiktoken_encoding_cache: dict[str, Any] = {}
|
|
||||||
_tiktoken_encoding_cache_lock = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
|
||||||
"""Return a cached tiktoken encoding, or ``None`` on failure / unavailability.
|
|
||||||
|
|
||||||
On the very first call for a given *encoding_name*, tiktoken may need to
|
|
||||||
download the BPE data from ``openaipublic.blob.core.windows.net``. In
|
|
||||||
network-restricted environments (e.g. deployments behind the GFW) this
|
|
||||||
download can block for tens of minutes before the OS TCP timeout kicks in.
|
|
||||||
The caller must therefore be prepared for this to block and should run it
|
|
||||||
off the event loop (e.g. via ``asyncio.to_thread``).
|
|
||||||
|
|
||||||
A failed load is remembered (with a timestamp) so subsequent calls fall
|
|
||||||
back immediately to character-based estimation instead of re-triggering the
|
|
||||||
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
|
||||||
so a transient outage can self-heal without a restart. A load already in
|
|
||||||
progress is also remembered so that a timed-out caller does not leave a
|
|
||||||
window where later requests start more blocking ``get_encoding`` calls.
|
|
||||||
"""
|
|
||||||
if not TIKTOKEN_AVAILABLE:
|
|
||||||
return None
|
|
||||||
|
|
||||||
with _tiktoken_encoding_cache_lock:
|
|
||||||
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
|
||||||
if cached is _TIKTOKEN_ENCODING_LOADING:
|
|
||||||
return None
|
|
||||||
if isinstance(cached, tuple):
|
|
||||||
# Cached failure: (None, failed_at). Retry only after cooldown.
|
|
||||||
_, failed_at = cached
|
|
||||||
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
|
||||||
return None
|
|
||||||
cached = _TIKTOKEN_ENCODING_MISSING
|
|
||||||
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
|
||||||
return cast("tiktoken.Encoding", cached)
|
|
||||||
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
|
||||||
|
|
||||||
try:
|
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
|
||||||
with _tiktoken_encoding_cache_lock:
|
|
||||||
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
|
||||||
return None
|
|
||||||
|
|
||||||
with _tiktoken_encoding_cache_lock:
|
|
||||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
|
||||||
return encoding
|
|
||||||
|
|
||||||
|
|
||||||
def _char_based_token_estimate(text: str) -> int:
|
|
||||||
"""Network-free token estimate that accounts for CJK density.
|
|
||||||
|
|
||||||
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
|
||||||
(~4 chars per token) but significantly under-estimates token counts for
|
|
||||||
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
|
||||||
characters per token. Counting CJK characters separately (~2 chars per
|
|
||||||
token) avoids over-filling the injection budget for CJK-heavy memory
|
|
||||||
content.
|
|
||||||
"""
|
|
||||||
cjk = sum(
|
|
||||||
1
|
|
||||||
for ch in text
|
|
||||||
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
|
||||||
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
|
||||||
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
|
||||||
)
|
|
||||||
return (len(text) - cjk) // 4 + cjk // 2
|
|
||||||
|
|
||||||
|
|
||||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
|
||||||
"""Count tokens in text using tiktoken.
|
"""Count tokens in text using tiktoken.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The text to count tokens for.
|
text: The text to count tokens for.
|
||||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||||
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
|
||||||
network-free character-based estimate. This guarantees no BPE
|
|
||||||
download is attempted (see ``memory.token_counting`` config).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of tokens in the text.
|
The number of tokens in the text.
|
||||||
"""
|
"""
|
||||||
if not use_tiktoken:
|
if not TIKTOKEN_AVAILABLE:
|
||||||
return _char_based_token_estimate(text)
|
# Fallback to character-based estimation if tiktoken is not available
|
||||||
|
return len(text) // 4
|
||||||
encoding = _get_tiktoken_encoding(encoding_name)
|
|
||||||
if encoding is None:
|
|
||||||
# Fallback to CJK-aware character estimation if tiktoken is not
|
|
||||||
# available or the encoding failed to load.
|
|
||||||
return _char_based_token_estimate(text)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
encoding = tiktoken.get_encoding(encoding_name)
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to CJK-aware character estimation on error.
|
# Fallback to character-based estimation on error
|
||||||
return _char_based_token_estimate(text)
|
return len(text) // 4
|
||||||
|
|
||||||
|
|
||||||
def warm_tiktoken_cache() -> bool:
|
|
||||||
"""Pre-warm the tiktoken encoding cache.
|
|
||||||
|
|
||||||
Call at startup (off the event loop) so the first request never blocks
|
|
||||||
on the BPE download. Returns ``True`` if the encoding was loaded
|
|
||||||
successfully (or was already cached), ``False`` if tiktoken is
|
|
||||||
unavailable or the download failed.
|
|
||||||
"""
|
|
||||||
return _get_tiktoken_encoding("cl100k_base") is not None
|
|
||||||
|
|
||||||
|
|
||||||
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||||
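Reviewer note: the cached-load logic in `_get_tiktoken_encoding` (a loading sentinel plus a failure timestamp that expires after a cooldown) can be sketched without tiktoken or network access. The class below is a minimal illustration, not the repository's code: `CooldownCache`, the injectable `loader` and `clock` parameters, and `flaky_loader` are all assumptions introduced for the example.

```python
import threading
import time
from typing import Any, Callable

_MISSING = object()   # key never requested
_LOADING = object()   # a load for this key is already in progress


class CooldownCache:
    """Hypothetical cache mirroring the pattern in the hunk above.

    Cached values must not be tuples: a tuple marks a failed load.
    """

    def __init__(self, loader: Callable[[str], Any], cooldown_s: float = 600.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self._loader = loader
        self._cooldown_s = cooldown_s
        self._clock = clock
        self._entries: dict[str, Any] = {}
        self._lock = threading.Lock()

    def get(self, key: str) -> Any:
        with self._lock:
            cached = self._entries.get(key, _MISSING)
            if cached is _LOADING:
                return None  # another caller is loading; fall back immediately
            if isinstance(cached, tuple):
                _, failed_at = cached
                if self._clock() - failed_at < self._cooldown_s:
                    return None  # recent failure; do not retry yet
                cached = _MISSING  # cooldown elapsed; allow a retry
            if cached is not _MISSING:
                return cached
            self._entries[key] = _LOADING
        try:
            value = self._loader(key)  # potentially slow; runs outside the lock
        except Exception:
            with self._lock:
                self._entries[key] = (None, self._clock())
            return None
        with self._lock:
            self._entries[key] = value
        return value


calls: list[str] = []
now = [0.0]  # fake clock so the example needs no real waiting


def flaky_loader(key: str) -> str:
    calls.append(key)
    if len(calls) == 1:
        raise OSError("network unreachable")
    return f"encoding:{key}"


cache = CooldownCache(flaky_loader, cooldown_s=600.0, clock=lambda: now[0])
first = cache.get("cl100k_base")    # load fails; failure is cached
second = cache.get("cl100k_base")   # within cooldown; loader is not called
calls_during_cooldown = len(calls)
now[0] = 601.0
third = cache.get("cl100k_base")    # cooldown expired; loader is retried
```

The slow call runs outside the lock so that concurrent callers only contend on the dictionary lookup, which matches the structure of the function in the diff.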
@@ -316,15 +198,12 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
|||||||
return max(0.0, min(1.0, confidence))
|
return max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
|
||||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||||
"""Format memory data for injection into system prompt.
|
"""Format memory data for injection into system prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: The memory data dictionary.
|
memory_data: The memory data dictionary.
|
||||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||||
use_tiktoken: When ``False``, all token counting uses the network-free
|
|
||||||
character-based estimate instead of tiktoken (see
|
|
||||||
``memory.token_counting`` config). Defaults to ``True``.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted memory string for system prompt injection.
|
Formatted memory string for system prompt injection.
|
||||||
@@ -386,10 +265,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
# Compute token count for existing sections once, then account
|
# Compute token count for existing sections once, then account
|
||||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||||
base_text = "\n\n".join(sections)
|
base_text = "\n\n".join(sections)
|
||||||
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
base_tokens = _count_tokens(base_text) if base_text else 0
|
||||||
# Account for the separator between existing sections and the facts section.
|
# Account for the separator between existing sections and the facts section.
|
||||||
facts_header = "Facts:\n"
|
facts_header = "Facts:\n"
|
||||||
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
||||||
running_tokens = base_tokens + separator_tokens
|
running_tokens = base_tokens + separator_tokens
|
||||||
|
|
||||||
fact_lines: list[str] = []
|
fact_lines: list[str] = []
|
||||||
@@ -410,7 +289,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
# Each additional line is preceded by a newline (except the first).
|
# Each additional line is preceded by a newline (except the first).
|
||||||
line_text = ("\n" + line) if fact_lines else line
|
line_text = ("\n" + line) if fact_lines else line
|
||||||
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
line_tokens = _count_tokens(line_text)
|
||||||
|
|
||||||
if running_tokens + line_tokens <= max_tokens:
|
if running_tokens + line_tokens <= max_tokens:
|
||||||
fact_lines.append(line)
|
fact_lines.append(line)
|
||||||
@@ -426,9 +305,8 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
result = "\n\n".join(sections)
|
result = "\n\n".join(sections)
|
||||||
|
|
||||||
# Use accurate token counting with tiktoken (or the char-based estimate
|
# Use accurate token counting with tiktoken
|
||||||
# when use_tiktoken is False).
|
token_count = _count_tokens(result)
|
||||||
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
|
||||||
if token_count > max_tokens:
|
if token_count > max_tokens:
|
||||||
# Truncate to fit within token limit
|
# Truncate to fit within token limit
|
||||||
# Estimate characters to remove based on token ratio
|
# Estimate characters to remove based on token ratio
|
||||||
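Reviewer note: the budget checks in `format_memory_for_injection` depend on the fallback estimate whenever tiktoken is skipped or unavailable. The sketch below copies the CJK-aware heuristic from `_char_based_token_estimate` as shown in the diff and compares it with the plain `len(text) // 4` fallback; the sample strings are illustrative only.

```python
def char_based_token_estimate(text: str) -> int:
    """Copy of the heuristic in the diff: ~4 chars/token, ~2 chars/token for CJK."""
    cjk = sum(
        1
        for ch in text
        if "\u4e00" <= ch <= "\u9fff"    # CJK Unified Ideographs
        or "\u3040" <= ch <= "\u30ff"    # Hiragana + Katakana
        or "\uac00" <= ch <= "\ud7a3"    # Hangul syllables
    )
    return (len(text) - cjk) // 4 + cjk // 2


def naive_estimate(text: str) -> int:
    """The simpler fallback that appears on the other side of the diff."""
    return len(text) // 4


english = "hello world!"                 # 12 non-CJK characters
chinese = "\u4f60\u597d\u4e16\u754c"     # four CJK ideographs

english_tokens = char_based_token_estimate(english)
chinese_tokens = char_based_token_estimate(chinese)
chinese_naive = naive_estimate(chinese)
print(english_tokens, chinese_tokens, chinese_naive)
```

For pure English input both estimates agree; for CJK-heavy input the naive version reports roughly half as many tokens, which is the under-counting the docstring in the diff describes.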
|
|||||||
@@ -28,7 +28,6 @@ Date-update format:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@@ -44,12 +43,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Upper bound (seconds) for a single _inject() offload. If the warm-up at
|
|
||||||
# gateway startup failed silently, the first request may still hit a cold
|
|
||||||
# tiktoken BPE download that blocks until the OS TCP timeout (~26 min).
|
|
||||||
# This cap ensures the request degrades gracefully instead of hanging.
|
|
||||||
_INJECT_TIMEOUT_SECONDS = 5.0
|
|
||||||
|
|
||||||
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
_DATE_RE = re.compile(r"<current_date>([^<]+)</current_date>")
|
||||||
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
_DYNAMIC_CONTEXT_REMINDER_KEY = "dynamic_context_reminder"
|
||||||
_SUMMARY_MESSAGE_NAME = "summary"
|
_SUMMARY_MESSAGE_NAME = "summary"
|
||||||
@@ -208,25 +201,4 @@ class DynamicContextMiddleware(AgentMiddleware):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
async def abefore_agent(self, state, runtime: Runtime) -> dict | None:
|
||||||
# _inject() performs synchronous file I/O (memory JSON loading) and
|
return self._inject(state)
|
||||||
# potentially blocking network calls (tiktoken encoding download on
|
|
||||||
# first use). Offload to a thread so the event loop is never blocked
|
|
||||||
# — a blocking call here starves all concurrent HTTP handlers (auth,
|
|
||||||
# SSE heartbeats, etc.). See issue #3402.
|
|
||||||
#
|
|
||||||
# Bounded timeout: if startup warm-up failed silently (e.g. network
|
|
||||||
# blip during deploy), the first request's cold tiktoken download can
|
|
||||||
# block for tens of minutes (OS TCP timeout). Time-box injection so
|
|
||||||
# the request degrades gracefully (no memory context) rather than
|
|
||||||
# hanging.
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(
|
|
||||||
asyncio.to_thread(self._inject, state),
|
|
||||||
timeout=_INJECT_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
except TimeoutError:
|
|
||||||
logger.warning(
|
|
||||||
"DynamicContextMiddleware: injection timed out (%.1fs); skipping memory/date injection for this turn",
|
|
||||||
_INJECT_TIMEOUT_SECONDS,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
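Reviewer note: the time-boxed offload in `abefore_agent` combines `asyncio.to_thread` with `asyncio.wait_for`. The following is a minimal, self-contained sketch of that pattern; `slow_inject` and `bounded_inject` are hypothetical stand-ins for `self._inject` and the middleware hook, and the delays are chosen only to keep the example fast.

```python
import asyncio
import time


def slow_inject(delay_s: float) -> dict:
    """Hypothetical blocking work (stands in for file I/O or a cold download)."""
    time.sleep(delay_s)
    return {"messages": ["injected"]}


async def bounded_inject(delay_s: float, timeout_s: float):
    """Run the blocking call in a worker thread and give up after timeout_s."""
    try:
        return await asyncio.wait_for(
            asyncio.to_thread(slow_inject, delay_s),
            timeout=timeout_s,
        )
    except asyncio.TimeoutError:
        # Degrade gracefully: the caller proceeds without the injected context.
        return None


async def main():
    fast = await bounded_inject(0.01, timeout_s=1.0)
    slow = await bounded_inject(0.3, timeout_s=0.05)
    return fast, slow


fast, slow = asyncio.run(main())
print(fast, slow)
```

One caveat worth keeping in mind: the timeout only stops the awaiting coroutine. The worker thread itself is not interrupted and runs to completion in the background, which is presumably why the cache code elsewhere in this diff tracks an in-progress load.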
|
|||||||
@@ -1,289 +0,0 @@
|
|||||||
"""Middleware for explicit slash skill activation."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import html
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, override
|
|
||||||
|
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
|
||||||
from langchain.agents.middleware.types import ModelRequest, ModelResponse
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
|
||||||
|
|
||||||
from deerflow.skills.slash import parse_slash_skill_reference, resolve_slash_skill
|
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
|
||||||
from deerflow.skills.storage.skill_storage import SkillStorage
|
|
||||||
from deerflow.skills.types import SKILL_MD_FILE
|
|
||||||
from deerflow.utils.messages import get_original_user_content_text
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.config.app_config import AppConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_SLASH_SKILL_ACTIVATION_KEY = "slash_skill_activation"
|
|
||||||
_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY = "slash_skill_activation_target_id"
|
|
||||||
_SUMMARY_MESSAGE_NAME = "summary"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class _Activation:
|
|
||||||
skill_name: str
|
|
||||||
category: str
|
|
||||||
container_file_path: str
|
|
||||||
skill_content: str
|
|
||||||
content_hash: str
|
|
||||||
remaining_text: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True, slots=True)
|
|
||||||
class _ActivationResolution:
|
|
||||||
activation: _Activation | None = None
|
|
||||||
failure_message: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def is_slash_skill_activation_reminder(message: object) -> bool:
|
|
||||||
"""Return whether a message is hidden slash-skill activation context."""
|
|
||||||
return isinstance(message, HumanMessage) and bool(message.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_KEY))
|
|
||||||
|
|
||||||
|
|
||||||
def _is_user_activation_target(message: object) -> bool:
|
|
||||||
if not isinstance(message, HumanMessage):
|
|
||||||
return False
|
|
||||||
if message.name == _SUMMARY_MESSAGE_NAME:
|
|
||||||
return False
|
|
||||||
if message.additional_kwargs.get("hide_from_ui"):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class SkillActivationMiddleware(AgentMiddleware):
|
|
||||||
"""Inject full SKILL.md content when the user explicitly types /skill-name."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
available_skills: set[str] | None = None,
|
|
||||||
app_config: AppConfig | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._available_skills = set(available_skills) if available_skills is not None else None
|
|
||||||
self._app_config = app_config
|
|
||||||
|
|
||||||
def _storage(self) -> SkillStorage:
|
|
||||||
if self._app_config is not None:
|
|
||||||
return get_or_new_skill_storage(app_config=self._app_config)
|
|
||||||
return get_or_new_skill_storage()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _read_skill_content(skill_file: Path, skills_root: Path) -> str:
|
|
||||||
if skill_file.name != SKILL_MD_FILE:
|
|
||||||
raise ValueError(f"Expected {SKILL_MD_FILE}, got {skill_file.name}")
|
|
||||||
resolved_root = skills_root.resolve()
|
|
||||||
resolved_file = skill_file.resolve()
|
|
||||||
try:
|
|
||||||
resolved_file.relative_to(resolved_root)
|
|
||||||
except ValueError as exc:
|
|
||||||
raise ValueError("Resolved skill file must stay within the configured skills root.") from exc
|
|
||||||
if not resolved_file.is_file():
|
|
||||||
raise FileNotFoundError(resolved_file)
|
|
||||||
return resolved_file.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
def _resolve_activation(self, text: str) -> _ActivationResolution | None:
|
|
||||||
reference = parse_slash_skill_reference(text)
|
|
||||||
if reference is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
storage = self._storage()
|
|
||||||
skills = storage.load_skills(enabled_only=False)
|
|
||||||
skill = next((candidate for candidate in skills if candidate.name == reference.name), None)
|
|
||||||
if skill is None:
|
|
||||||
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not installed.")
|
|
||||||
if not skill.enabled:
|
|
||||||
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is installed but disabled. Enable it before using slash activation.")
|
|
||||||
if self._available_skills is not None and reference.name not in self._available_skills:
|
|
||||||
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` is not available for this agent.")
|
|
||||||
|
|
||||||
resolved = resolve_slash_skill(
|
|
||||||
text,
|
|
||||||
skills,
|
|
||||||
available_skills=self._available_skills,
|
|
||||||
container_base_path=storage.get_container_root(),
|
|
||||||
)
|
|
||||||
if resolved is None:
|
|
||||||
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be resolved.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
skill_content = self._read_skill_content(resolved.skill.skill_file, storage.get_skills_root_path())
|
|
||||||
except (OSError, ValueError):
|
|
||||||
logger.exception("Failed to read slash-activated skill %s", resolved.skill.name)
|
|
||||||
return _ActivationResolution(failure_message=f"Skill `/{reference.name}` could not be loaded safely. Please check the skill installation.")
|
|
||||||
|
|
||||||
content_hash = hashlib.sha256(skill_content.encode("utf-8")).hexdigest()
|
|
||||||
return _ActivationResolution(
|
|
||||||
activation=_Activation(
|
|
||||||
skill_name=resolved.skill.name,
|
|
||||||
category=str(resolved.skill.category),
|
|
||||||
container_file_path=resolved.container_file_path,
|
|
||||||
skill_content=skill_content,
|
|
||||||
content_hash=content_hash,
|
|
||||||
remaining_text=resolved.remaining_text,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_activation_reminder(activation: _Activation) -> str:
|
|
||||||
user_request = activation.remaining_text or ("No additional task text was provided after the slash skill command. Ask the user what they want to do with this skill if the next step is unclear.")
|
|
||||||
escaped_user_request = html.escape(user_request, quote=False)
|
|
||||||
escaped_skill_content = html.escape(activation.skill_content, quote=False)
|
|
||||||
escaped_skill_name = html.escape(activation.skill_name, quote=True)
|
|
||||||
escaped_category = html.escape(activation.category, quote=True)
|
|
||||||
escaped_path = html.escape(activation.container_file_path, quote=True)
|
|
||||||
escaped_content_hash = html.escape(activation.content_hash, quote=True)
|
|
||||||
return f"""<slash_skill_activation>
|
|
||||||
The user explicitly activated the `{activation.skill_name}` skill for this turn.
|
|
||||||
Treat the task text as:
|
|
||||||
<user_request>
|
|
||||||
{escaped_user_request}
|
|
||||||
</user_request>
|
|
||||||
|
|
||||||
Follow this skill before choosing a general workflow. Load supporting resources from the same skill directory only when needed.
|
|
||||||
|
|
||||||
<skill name="{escaped_skill_name}" category="{escaped_category}" path="{escaped_path}" sha256="{escaped_content_hash}">
|
|
||||||
<skill_content encoding="xml-escaped">
|
|
||||||
{escaped_skill_content}
|
|
||||||
</skill_content>
|
|
||||||
</skill>
|
|
||||||
</slash_skill_activation>"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _has_existing_activation_for_target(messages: list, target_index: int, target: HumanMessage) -> bool:
|
|
||||||
if target_index <= 0:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if target.id:
|
|
||||||
for previous in messages[:target_index]:
|
|
||||||
if not is_slash_skill_activation_reminder(previous):
|
|
||||||
continue
|
|
||||||
target_id = previous.additional_kwargs.get(_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY)
|
|
||||||
if target_id == target.id or previous.id == f"{target.id}__slash_activation":
|
|
||||||
return True
|
|
||||||
|
|
||||||
previous = messages[target_index - 1]
|
|
||||||
return is_slash_skill_activation_reminder(previous)
|
|
||||||
|
|
||||||
def _find_activation_target(self, messages: list) -> tuple[int, HumanMessage, _ActivationResolution] | None:
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
target_index = next((idx for idx in range(len(messages) - 1, -1, -1) if _is_user_activation_target(messages[idx])), None)
|
|
||||||
if target_index is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
target = messages[target_index]
|
|
||||||
if target is None:
|
|
||||||
return None
|
|
||||||
if self._has_existing_activation_for_target(messages, target_index, target):
|
|
||||||
return None
|
|
||||||
|
|
||||||
content = get_original_user_content_text(target.content, target.additional_kwargs)
|
|
||||||
resolution = self._resolve_activation(content)
|
|
||||||
if resolution is None:
|
|
||||||
return None
|
|
||||||
return target_index, target, resolution
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _record_activation(request: ModelRequest, activation: _Activation, *, hook: str) -> None:
|
|
||||||
runtime = getattr(request, "runtime", None)
|
|
||||||
context = getattr(runtime, "context", None)
|
|
||||||
journal = context.get("__run_journal") if isinstance(context, dict) else None
|
|
||||||
if journal is None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
journal.record_middleware(
|
|
||||||
"skill_activation",
|
|
||||||
name="SkillActivationMiddleware",
|
|
||||||
hook=hook,
|
|
||||||
action="activate",
|
|
||||||
changes={
|
|
||||||
"skill_name": activation.skill_name,
|
|
||||||
"category": activation.category,
|
|
||||||
"path": activation.container_file_path,
|
|
||||||
"content_hash": activation.content_hash,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Failed to record slash skill activation audit event", exc_info=True)
|
|
||||||
|
|
||||||
def _prepare_model_request(self, request: ModelRequest, *, hook: str) -> ModelRequest | AIMessage | None:
|
|
||||||
target_and_resolution = self._find_activation_target(list(request.messages))
|
|
||||||
if target_and_resolution is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
target_index, target, resolution = target_and_resolution
|
|
||||||
if resolution.failure_message:
|
|
||||||
return AIMessage(content=resolution.failure_message)
|
|
||||||
|
|
||||||
activation = resolution.activation
|
|
||||||
if activation is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"SkillActivationMiddleware: activating slash skill %s category=%s path=%s hash=%s",
|
|
||||||
activation.skill_name,
|
|
||||||
activation.category,
|
|
||||||
activation.container_file_path,
|
|
||||||
activation.content_hash,
|
|
||||||
)
|
|
||||||
self._record_activation(request, activation, hook=hook)
|
|
||||||
activation_msg = self._make_activation_message(target, self._build_activation_reminder(activation))
|
|
||||||
messages = list(request.messages)
|
|
||||||
messages.insert(target_index, activation_msg)
|
|
||||||
return request.override(messages=messages)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _make_activation_message(target: HumanMessage, activation_content: str) -> HumanMessage:
|
|
||||||
stable_id = target.id or str(uuid.uuid4())
|
|
||||||
additional_kwargs = {
|
|
||||||
"hide_from_ui": True,
|
|
||||||
_SLASH_SKILL_ACTIVATION_KEY: True,
|
|
||||||
}
|
|
||||||
if target.id:
|
|
||||||
additional_kwargs[_SLASH_SKILL_ACTIVATION_TARGET_ID_KEY] = target.id
|
|
||||||
return HumanMessage(
|
|
||||||
content=activation_content,
|
|
||||||
id=f"{stable_id}__slash_activation",
|
|
||||||
additional_kwargs=additional_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
@override
|
|
||||||
def wrap_model_call(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
handler: Callable[[ModelRequest], ModelResponse],
|
|
||||||
) -> ModelResponse | AIMessage:
|
|
||||||
prepared = self._prepare_model_request(request, hook="wrap_model_call")
|
|
||||||
if prepared is None:
|
|
||||||
return handler(request)
|
|
||||||
if isinstance(prepared, AIMessage):
|
|
||||||
return prepared
|
|
||||||
return handler(prepared)
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def awrap_model_call(
|
|
||||||
self,
|
|
||||||
request: ModelRequest,
|
|
||||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
||||||
) -> ModelResponse | AIMessage:
|
|
||||||
prepared = await asyncio.to_thread(self._prepare_model_request, request, hook="awrap_model_call")
|
|
||||||
if prepared is None:
|
|
||||||
return await handler(request)
|
|
||||||
if isinstance(prepared, AIMessage):
|
|
||||||
return prepared
|
|
||||||
return await handler(prepared)
|
|
||||||
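Reviewer note: two safety measures in `SkillActivationMiddleware` are easy to check in isolation: the containment check in `_read_skill_content` and the `html.escape` call applied to skill content before it is embedded in the reminder. The sketch below reproduces both against a temporary directory. The value of `SKILL_MD_FILE` is an assumption (the diff imports it from `deerflow.skills.types` without showing its value), and the directory layout is invented for the example.

```python
import html
import tempfile
from pathlib import Path

SKILL_MD_FILE = "SKILL.md"  # assumed value; imported from deerflow.skills.types in the diff


def read_skill_content(skill_file: Path, skills_root: Path) -> str:
    """Copy of the containment logic shown in the diff."""
    if skill_file.name != SKILL_MD_FILE:
        raise ValueError(f"Expected {SKILL_MD_FILE}, got {skill_file.name}")
    resolved_root = skills_root.resolve()
    resolved_file = skill_file.resolve()
    try:
        # relative_to raises ValueError when the file is outside the root.
        resolved_file.relative_to(resolved_root)
    except ValueError as exc:
        raise ValueError("Resolved skill file must stay within the configured skills root.") from exc
    if not resolved_file.is_file():
        raise FileNotFoundError(resolved_file)
    return resolved_file.read_text(encoding="utf-8")


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "skills"
    (root / "demo").mkdir(parents=True)
    (root / "demo" / SKILL_MD_FILE).write_text("# Demo <skill>", encoding="utf-8")
    # A file with the right name, but outside the skills root.
    (Path(tmp) / SKILL_MD_FILE).write_text("secret", encoding="utf-8")

    content = read_skill_content(root / "demo" / SKILL_MD_FILE, root)
    escaped = html.escape(content, quote=False)

    try:
        read_skill_content(root / "demo" / ".." / ".." / SKILL_MD_FILE, root)
        traversal_blocked = False
    except ValueError:
        traversal_blocked = True

print(escaped, traversal_blocked)
```

Resolving both paths before comparing them means `..` segments and symlinks are normalised first, so a path that merely looks like it starts under the root cannot slip through.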
@@ -46,6 +46,11 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _completion_reminder_count(messages: list[Any]) -> int:
|
||||||
|
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
|
||||||
|
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
|
||||||
|
|
||||||
|
|
||||||
def _format_todos(todos: list[Todo]) -> str:
|
def _format_todos(todos: list[Todo]) -> str:
|
||||||
"""Format a list of Todo items into a human-readable string."""
|
"""Format a list of Todo items into a human-readable string."""
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
|
|||||||
+4
-74
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import TYPE_CHECKING, override
|
from typing import override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
@@ -12,48 +12,10 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
|
|||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from deerflow.config.app_config import AppConfig
|
from deerflow.config.app_config import AppConfig
|
||||||
from deerflow.subagents.status_contract import (
|
|
||||||
extract_subagent_status,
|
|
||||||
make_subagent_additional_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.tools.builtins.tool_search import DeferredToolSetup
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||||
_TASK_TOOL_NAME = "task"
|
|
||||||
|
|
||||||
|
|
||||||
def _stamp_task_subagent_status(message: ToolMessage, *, tool_name: str, error: str | None = None) -> ToolMessage:
|
|
||||||
"""Centralised stamping of ``additional_kwargs.subagent_status``.
|
|
||||||
|
|
||||||
Bytedance/deer-flow issue #3146: the frontend now reads the subagent
|
|
||||||
status from a structured field instead of parsing the leading text of
|
|
||||||
the task tool's return string. That contract is enforced here, in the
|
|
||||||
one place every task tool result flows through, rather than at the 5
|
|
||||||
normal-return + 3 ``Error:`` pre-execution branches inside
|
|
||||||
``task_tool.py``. Centralisation prevents the "added a new return
|
|
||||||
path, forgot the stamp" drift mode.
|
|
||||||
|
|
||||||
For non-``task`` tools this is a no-op so other tools' additional_kwargs
|
|
||||||
conventions are untouched.
|
|
||||||
"""
|
|
||||||
if tool_name != _TASK_TOOL_NAME:
|
|
||||||
return message
|
|
||||||
content = message.content if isinstance(message.content, str) else ""
|
|
||||||
status = extract_subagent_status(content)
|
|
||||||
if status is None:
|
|
||||||
# Non-terminal streaming chunks or unrecognised shapes leave the
|
|
||||||
# field unset so the frontend can keep the card on its in-progress
|
|
||||||
# placeholder until a real terminal frame arrives.
|
|
||||||
return message
|
|
||||||
stamp = make_subagent_additional_kwargs(status, error=error)
|
|
||||||
existing = dict(message.additional_kwargs or {})
|
|
||||||
existing.update(stamp)
|
|
||||||
message.additional_kwargs = existing
|
|
||||||
return message
|
|
||||||
|
|
||||||
|
|
||||||
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||||
@@ -67,31 +29,12 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
detail = detail[:497] + "..."
|
detail = detail[:497] + "..."
|
||||||
|
|
||||||
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
|
content = f"Error: Tool '{tool_name}' failed with {exc.__class__.__name__}: {detail}. Continue with available context, or choose an alternative tool."
|
||||||
message = ToolMessage(
|
return ToolMessage(
|
||||||
content=content,
|
content=content,
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
# Stamp the structured subagent status on the wrapper too: the
|
|
||||||
# frontend would otherwise have to fall back to prefix-matching
|
|
||||||
# ``Error: Tool 'task' failed ...`` on the wire. The ``subagent_error``
|
|
||||||
# carries the same ``ExcClass: detail`` shape the wrapper string
|
|
||||||
# uses so debugging artifacts stay aligned.
|
|
||||||
structured_error = f"{exc.__class__.__name__}: {detail}"
|
|
||||||
return _stamp_task_subagent_status(message, tool_name=tool_name, error=structured_error)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _maybe_stamp(result: ToolMessage | Command, request: ToolCallRequest) -> ToolMessage | Command:
|
|
||||||
"""Apply the subagent stamp to successful task tool returns.
|
|
||||||
|
|
||||||
``Command`` results bypass the stamp — they encode LangGraph
|
|
||||||
control flow rather than user-facing tool output.
|
|
||||||
"""
|
|
||||||
if not isinstance(result, ToolMessage):
|
|
||||||
return result
|
|
||||||
tool_name = str(request.tool_call.get("name") or "")
|
|
||||||
return _stamp_task_subagent_status(result, tool_name=tool_name)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def wrap_tool_call(
|
def wrap_tool_call(
|
||||||
@@ -100,14 +43,13 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||||
) -> ToolMessage | Command:
|
) -> ToolMessage | Command:
|
||||||
try:
|
try:
|
||||||
result = handler(request)
|
return handler(request)
|
||||||
except GraphBubbleUp:
|
except GraphBubbleUp:
|
||||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
logger.exception("Tool execution failed (sync): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||||
return self._build_error_message(request, exc)
|
return self._build_error_message(request, exc)
|
||||||
return self._maybe_stamp(result, request)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
@@ -116,14 +58,13 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
|||||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||||
) -> ToolMessage | Command:
|
) -> ToolMessage | Command:
|
||||||
try:
|
try:
|
||||||
result = await handler(request)
|
return await handler(request)
|
||||||
except GraphBubbleUp:
|
except GraphBubbleUp:
|
||||||
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
# Preserve LangGraph control-flow signals (interrupt/pause/resume).
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
logger.exception("Tool execution failed (async): name=%s id=%s", request.tool_call.get("name"), request.tool_call.get("id"))
|
||||||
return self._build_error_message(request, exc)
|
return self._build_error_message(request, exc)
|
||||||
return self._maybe_stamp(result, request)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_runtime_middlewares(
|
def _build_runtime_middlewares(
|
||||||
@@ -202,7 +143,6 @@ def build_subagent_runtime_middlewares(
|
|||||||
app_config: AppConfig | None = None,
|
app_config: AppConfig | None = None,
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
lazy_init: bool = True,
|
lazy_init: bool = True,
|
||||||
deferred_setup: "DeferredToolSetup | None" = None,
|
|
||||||
) -> list[AgentMiddleware]:
|
) -> list[AgentMiddleware]:
|
||||||
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
|
||||||
if app_config is None:
|
if app_config is None:
|
||||||
@@ -226,16 +166,6 @@ def build_subagent_runtime_middlewares(
|
|||||||
|
|
||||||
middlewares.append(ViewImageMiddleware())
|
middlewares.append(ViewImageMiddleware())
|
||||||
|
|
||||||
# Hide deferred (MCP) tool schemas from the subagent's model binding until
|
|
||||||
# tool_search promotes them. This is the same wiring the lead agent gets. The deferred
|
|
||||||
# set + catalog hash come from the build-time setup (assembled after
|
|
||||||
# tool-policy filtering); promotion is read from graph state. Empty/None
|
|
||||||
# setup (deferral disabled or no MCP tool survived) is a pure no-op.
|
|
||||||
if deferred_setup is not None and deferred_setup.deferred_names:
|
|
||||||
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
|
|
||||||
|
|
||||||
middlewares.append(DeferredToolFilterMiddleware(deferred_setup.deferred_names, deferred_setup.catalog_hash))
|
|
||||||
|
|
||||||
# Same provider safety-termination guard the lead agent uses — subagents
|
# Same provider safety-termination guard the lead agent uses — subagents
|
||||||
# are equally exposed to truncated tool_calls returned with
|
# are equally exposed to truncated tool_calls returned with
|
||||||
# finish_reason=content_filter (and friends), and the bad call would then
|
# finish_reason=content_filter (and friends), and the bad call would then
|
||||||
|
|||||||
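The removed `_stamp_task_subagent_status` helper routes every task tool result through one function, so that a new return path cannot skip the structured status field. A minimal sketch of that single-chokepoint pattern follows. `FakeToolMessage`, `extract_status`, the marker strings, and the value of `TASK_TOOL_NAME` are assumptions made for illustration; the diff does not show the real `extract_subagent_status` or `make_subagent_additional_kwargs` implementations.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

TASK_TOOL_NAME = "task"  # assumption: the real constant's value is not shown in the diff


@dataclass
class FakeToolMessage:
    """Stand-in for a tool message so the sketch runs without dependencies."""

    content: Any
    additional_kwargs: dict[str, Any] = field(default_factory=dict)


def extract_status(content: str) -> str | None:
    """Hypothetical parser: map a leading text marker to a terminal status."""
    if content.startswith("Error:"):
        return "failed"
    if content.startswith("Task completed"):
        return "completed"
    return None  # non-terminal or unrecognised shape


def stamp(message: FakeToolMessage, *, tool_name: str, error: str | None = None) -> FakeToolMessage:
    """Single chokepoint that writes the structured status field."""
    if tool_name != TASK_TOOL_NAME:
        return message  # other tools keep their own additional_kwargs conventions
    content = message.content if isinstance(message.content, str) else ""
    status = extract_status(content)
    if status is None:
        return message  # leave the field unset for in-progress frames
    merged = dict(message.additional_kwargs or {})  # preserve existing keys
    merged["subagent_status"] = status
    if error is not None:
        merged["subagent_error"] = error
    message.additional_kwargs = merged
    return message
```

Because both the success path and the error path call the same function, a consumer only needs to read `additional_kwargs["subagent_status"]` and never has to parse the message text.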
+14
-168
@@ -11,11 +11,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shlex
|
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import replace as dc_replace
|
from dataclasses import replace as dc_replace
|
||||||
from typing import TYPE_CHECKING, Any, override
|
from typing import Any, override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
from langchain.agents.middleware import AgentMiddleware
|
from langchain.agents.middleware import AgentMiddleware
|
||||||
@@ -25,19 +24,9 @@ from langgraph.prebuilt.tool_node import ToolCallRequest
|
|||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
from deerflow.config.tool_output_config import ToolOutputConfig
|
from deerflow.config.tool_output_config import ToolOutputConfig
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from deerflow.sandbox.sandbox import Sandbox
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Virtual outputs root inside the sandbox. Host-mounted sandboxes map this to
|
|
||||||
# the thread outputs dir on the host; for non-mounted (remote) sandboxes the
|
|
||||||
# same path is written directly into the sandbox filesystem so the model's
|
|
||||||
# ``read_file`` tool can read it back (issue #3416).
|
|
||||||
_VIRTUAL_OUTPUTS_BASE = "/mnt/user-data/outputs"
|
|
||||||
|
|
||||||
|
|
||||||
def _default_config() -> ToolOutputConfig:
|
def _default_config() -> ToolOutputConfig:
|
||||||
return ToolOutputConfig()
|
return ToolOutputConfig()
|
||||||
@@ -105,18 +94,6 @@ def _sanitize_tool_name(name: str) -> str:
|
|||||||
return safe or "unknown"
|
return safe or "unknown"
|
||||||
|
|
||||||
|
|
||||||
def _build_externalized_filename(*, tool_name: str, tool_call_id: str) -> str:
|
|
||||||
"""Build the on-disk filename for an externalized tool output.
|
|
||||||
|
|
||||||
Shared by the host-disk and sandbox externalization paths so both
|
|
||||||
produce the identical naming scheme.
|
|
||||||
"""
|
|
||||||
safe_name = _sanitize_tool_name(tool_name)
|
|
||||||
ext = _EXT_MAP.get(tool_name, "txt")
|
|
||||||
short_id = uuid.uuid4().hex[:12]
|
|
||||||
return f"{safe_name}-{short_id}.{ext}"
|
|
||||||
|
|
||||||
|
|
||||||
def _externalize(
|
def _externalize(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
@@ -134,7 +111,10 @@ def _externalize(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
filename = _build_externalized_filename(tool_name=tool_name, tool_call_id=tool_call_id)
|
safe_name = _sanitize_tool_name(tool_name)
|
||||||
|
ext = _EXT_MAP.get(tool_name, "txt")
|
||||||
|
short_id = uuid.uuid4().hex[:12]
|
||||||
|
filename = f"{safe_name}-{short_id}.{ext}"
|
||||||
filepath = os.path.join(storage_dir, filename)
|
filepath = os.path.join(storage_dir, filename)
|
||||||
|
|
||||||
if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)):
|
if not os.path.abspath(filepath).startswith(os.path.abspath(storage_dir)):
|
||||||
@@ -146,56 +126,8 @@ def _externalize(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return f"{_VIRTUAL_OUTPUTS_BASE}/{storage_subdir}/{filename}"
|
virtual_base = "/mnt/user-data/outputs"
|
||||||
|
return f"{virtual_base}/{storage_subdir}/{filename}"
|
||||||
|
|
||||||
def _externalize_to_sandbox(
|
|
||||||
content: str,
|
|
||||||
*,
|
|
||||||
tool_name: str,
|
|
||||||
tool_call_id: str,
|
|
||||||
storage_subdir: str,
|
|
||||||
sandbox: Sandbox,
|
|
||||||
) -> str | None:
|
|
||||||
"""Write *content* into the sandbox filesystem and return the virtual path.
|
|
||||||
|
|
||||||
Used when the sandbox does not use thread-data mounts (e.g. a remote AIO
|
|
||||||
sandbox): the host-side :func:`_externalize` virtual path would not exist
|
|
||||||
inside the sandbox, so the model's ``read_file`` tool could not read it
|
|
||||||
back (issue #3416). Returns the same virtual-path contract on success, or
|
|
||||||
``None`` to signal the caller to fall back to inline truncation.
|
|
||||||
"""
|
|
||||||
if os.path.isabs(storage_subdir) or ".." in storage_subdir:
|
|
||||||
return None
|
|
||||||
filename = _build_externalized_filename(tool_name=tool_name, tool_call_id=tool_call_id)
|
|
||||||
virtual_dir = f"{_VIRTUAL_OUTPUTS_BASE}/{storage_subdir}"
|
|
||||||
virtual_path = f"{virtual_dir}/{filename}"
|
|
||||||
try:
|
|
||||||
# AIO sandbox write_file does NOT create parent directories, so create
|
|
||||||
# them explicitly before writing. execute_command returns its stdout
|
|
||||||
# verbatim (including an "Error: ..." string on failure) rather than
|
|
||||||
# raising, so we cannot rely on exception propagation here.
|
|
||||||
sandbox.execute_command(f"mkdir -p {shlex.quote(virtual_dir)}")
|
|
||||||
sandbox.write_file(virtual_path, content)
|
|
||||||
# Validate the file landed: execute_command may have silently failed
|
|
||||||
# to create the directory, and write_file backends differ. Refuse to
|
|
||||||
# hand the model an unreadable read_file path.
|
|
||||||
check = sandbox.execute_command(f"test -s {shlex.quote(virtual_path)} && echo OK || echo MISSING")
|
|
||||||
if not isinstance(check, str) or check.strip() != "OK":
|
|
||||||
logger.warning(
|
|
||||||
"Sandbox externalize validation failed: path=%s, check=%r",
|
|
||||||
virtual_path,
|
|
||||||
check,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to externalize %s output to sandbox (call_id=%s)",
|
|
||||||
tool_name,
|
|
||||||
tool_call_id,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
return virtual_path
|
|
||||||
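The removed `_externalize_to_sandbox` function follows a write-then-verify pattern: create the directory, write the file, then confirm the file is non-empty before returning its path, because the command runner reports failures as text rather than by raising. The sketch below illustrates that pattern against a hypothetical in-memory sandbox. `InMemorySandbox` and its `drop_writes` flag are assumptions for illustration only and do not reflect the real `Sandbox` interface beyond the two method names used in the diff.

```python
from __future__ import annotations

import shlex


class InMemorySandbox:
    """Hypothetical sandbox whose commands return text instead of raising."""

    def __init__(self, *, drop_writes: bool = False) -> None:
        self.dirs: set[str] = set()
        self.files: dict[str, str] = {}
        self.drop_writes = drop_writes  # simulate a backend that silently loses the write

    def execute_command(self, command: str) -> str:
        argv = shlex.split(command)
        if argv[:2] == ["mkdir", "-p"]:
            self.dirs.add(argv[2])
            return ""
        if argv[:2] == ["test", "-s"]:
            # Mirrors `test -s`: true only for an existing, non-empty file.
            return "OK" if self.files.get(argv[2]) else "MISSING"
        return f"Error: unsupported command {argv[0]!r}"

    def write_file(self, path: str, content: str) -> None:
        if not self.drop_writes:
            self.files[path] = content


def externalize(sandbox: InMemorySandbox, directory: str, filename: str, content: str) -> str | None:
    """Write content and return its path, or None if the write cannot be confirmed."""
    path = f"{directory}/{filename}"
    # Quote every path so spaces or shell metacharacters cannot change the command.
    sandbox.execute_command(f"mkdir -p {shlex.quote(directory)}")
    sandbox.write_file(path, content)
    check = sandbox.execute_command(f"test -s {shlex.quote(path)} && echo OK || echo MISSING")
    if check.strip() != "OK":
        return None  # caller can fall back to inline truncation
    return path
```

Returning `None` instead of an unverified path keeps the caller from handing the model a file location that cannot be read back.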
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -295,33 +227,6 @@ def _resolve_outputs_path(request: ToolCallRequest) -> str | None:
|
|||||||
return outputs_path if isinstance(outputs_path, str) else None
|
return outputs_path if isinstance(outputs_path, str) else None
|
||||||
|
|
||||||
|
|
||||||
def _resolve_sandbox(request: ToolCallRequest) -> Sandbox | None:
|
|
||||||
"""Resolve the active sandbox for the current tool call, or ``None``.
|
|
||||||
|
|
||||||
Reads the sandbox_id that ``SandboxMiddleware`` (and the sandbox tools
|
|
||||||
themselves) write into ``runtime.state["sandbox"]``. We intentionally do
|
|
||||||
NOT call ``provider.acquire`` here: acquiring a sandbox can trigger
|
|
||||||
blocking remote I/O, and this resolver runs on every tool call. Tools
|
|
||||||
that do not use a sandbox (``web_search``, MCP, ...) will return ``None``
|
|
||||||
here, which is fine -- the caller falls back to inline truncation.
|
|
||||||
"""
|
|
||||||
runtime = getattr(request, "runtime", None)
|
|
||||||
state = getattr(runtime, "state", None)
|
|
||||||
if not isinstance(state, dict):
|
|
||||||
return None
|
|
||||||
sandbox_state = state.get("sandbox")
|
|
||||||
if not isinstance(sandbox_state, dict):
|
|
||||||
return None
|
|
||||||
sandbox_id = sandbox_state.get("sandbox_id")
|
|
||||||
if not sandbox_id:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return get_sandbox_provider().get(sandbox_id)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to look up sandbox %s for tool-output externalization", sandbox_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _budget_content(
|
def _budget_content(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
@@ -329,7 +234,6 @@ def _budget_content(
|
|||||||
tool_call_id: str,
|
tool_call_id: str,
|
||||||
outputs_path: str | None,
|
outputs_path: str | None,
|
||||||
config: ToolOutputConfig,
|
config: ToolOutputConfig,
|
||||||
sandbox: Sandbox | None = None,
|
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Apply budget to *content*. Returns ``None`` if no change needed."""
|
"""Apply budget to *content*. Returns ``None`` if no change needed."""
|
||||||
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
|
threshold = config.tool_overrides.get(tool_name, config.externalize_min_chars)
|
||||||
@@ -338,43 +242,7 @@ def _budget_content(
|
|||||||
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
|
if len(content) <= threshold and len(content) <= config.fallback_max_chars:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if threshold > 0 and len(content) > threshold:
|
if threshold > 0 and len(content) > threshold and outputs_path:
|
||||||
virtual_path: str | None = None
|
|
||||||
# Decide persistence target based on what's available, without touching
|
|
||||||
# the sandbox provider unless a sandbox was actually resolved for this
|
|
||||||
# call. This keeps the legacy host-disk path provider-free, so callers
|
|
||||||
# without a configured sandbox (and CI environments without a
|
|
||||||
# config.yaml) continue to externalize to the host as before.
|
|
||||||
if sandbox is not None:
|
|
||||||
provider = None
|
|
||||||
try:
|
|
||||||
provider = get_sandbox_provider()
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to get sandbox provider for tool-output externalization; falling back to inline truncation")
|
|
||||||
if provider is not None and getattr(provider, "uses_thread_data_mounts", False):
|
|
||||||
# Host-mounted sandbox: host outputs path is bind-mounted into
|
|
||||||
# the sandbox at the same virtual path, so writing host-side is
|
|
||||||
# equivalent. Preserve the original behavior to avoid extra
|
|
||||||
# sandbox round-trips.
|
|
||||||
if outputs_path:
|
|
||||||
virtual_path = _externalize(
|
|
||||||
content,
|
|
||||||
tool_name=tool_name,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
outputs_path=outputs_path,
|
|
||||||
storage_subdir=config.storage_subdir,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
virtual_path = _externalize_to_sandbox(
|
|
||||||
content,
|
|
||||||
tool_name=tool_name,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
storage_subdir=config.storage_subdir,
|
|
||||||
sandbox=sandbox,
|
|
||||||
)
|
|
||||||
elif outputs_path:
|
|
||||||
# No sandbox in this call (legacy / non-sandbox tools): write to
|
|
||||||
# host outputs path directly, no provider needed.
|
|
||||||
virtual_path = _externalize(
|
virtual_path = _externalize(
|
||||||
content,
|
content,
|
||||||
tool_name=tool_name,
|
tool_name=tool_name,
|
||||||
@@ -420,12 +288,7 @@ def _budget_content(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _patch_tool_message(
|
def _patch_tool_message(msg: ToolMessage, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage:
|
||||||
msg: ToolMessage,
|
|
||||||
config: ToolOutputConfig,
|
|
||||||
outputs_path: str | None,
|
|
||||||
sandbox: Sandbox | None = None,
|
|
||||||
) -> ToolMessage:
|
|
||||||
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
|
"""Apply budget to a single ToolMessage. Returns the original if unchanged."""
|
||||||
tool_name = msg.name or "unknown"
|
tool_name = msg.name or "unknown"
|
||||||
if tool_name in config.exempt_tools:
|
if tool_name in config.exempt_tools:
|
||||||
@@ -441,7 +304,6 @@ def _patch_tool_message(
|
|||||||
tool_call_id=msg.tool_call_id or "",
|
tool_call_id=msg.tool_call_id or "",
|
||||||
outputs_path=outputs_path,
|
outputs_path=outputs_path,
|
||||||
config=config,
|
config=config,
|
||||||
sandbox=sandbox,
|
|
||||||
)
|
)
|
||||||
if replacement is None:
|
if replacement is None:
|
||||||
return msg
|
return msg
|
||||||
@@ -493,15 +355,10 @@ def _needs_budget(result: ToolMessage | Command, config: ToolOutputConfig) -> bo
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _patch_result(
|
def _patch_result(result: ToolMessage | Command, config: ToolOutputConfig, outputs_path: str | None) -> ToolMessage | Command:
|
||||||
result: ToolMessage | Command,
|
|
||||||
config: ToolOutputConfig,
|
|
||||||
outputs_path: str | None,
|
|
||||||
sandbox: Sandbox | None = None,
|
|
||||||
) -> ToolMessage | Command:
|
|
||||||
"""Apply budget to a tool call result (ToolMessage or Command)."""
|
"""Apply budget to a tool call result (ToolMessage or Command)."""
|
||||||
if isinstance(result, ToolMessage):
|
if isinstance(result, ToolMessage):
|
||||||
return _patch_tool_message(result, config, outputs_path, sandbox)
|
return _patch_tool_message(result, config, outputs_path)
|
||||||
|
|
||||||
update = getattr(result, "update", None)
|
update = getattr(result, "update", None)
|
||||||
if not isinstance(update, dict):
|
if not isinstance(update, dict):
|
||||||
@@ -515,7 +372,7 @@ def _patch_result(
|
|||||||
changed = False
|
changed = False
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
patched = _patch_tool_message(msg, config, outputs_path, sandbox)
|
patched = _patch_tool_message(msg, config, outputs_path)
|
||||||
if patched is not msg:
|
if patched is not msg:
|
||||||
changed = True
|
changed = True
|
||||||
new_messages.append(patched)
|
new_messages.append(patched)
|
||||||
@@ -535,11 +392,6 @@ def _patch_model_messages(messages: list[Any], config: ToolOutputConfig) -> list
|
|||||||
ToolMessage exceeds the budget — the common case once every result has
|
ToolMessage exceeds the budget — the common case once every result has
|
||||||
already been budgeted at tool-call time, so a long history is not rebuilt
|
already been budgeted at tool-call time, so a long history is not rebuilt
|
||||||
on every model call.
|
on every model call.
|
||||||
|
|
||||||
Historical messages do not get a ``sandbox`` argument: any oversized tool
|
|
||||||
message in history was already budgeted (and possibly externalized) at
|
|
||||||
tool-call time, so the only thing left for the history path to do is
|
|
||||||
inline fallback truncation, which needs no sandbox.
|
|
||||||
"""
|
"""
|
||||||
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
|
if not any(isinstance(msg, ToolMessage) and _tool_message_over_budget(msg, config) for msg in messages):
|
||||||
return None
|
return None
|
||||||
@@ -590,8 +442,7 @@ class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if not _needs_budget(result, self._config):
|
if not _needs_budget(result, self._config):
|
||||||
return result
|
return result
|
||||||
outputs_path = _resolve_outputs_path(request)
|
outputs_path = _resolve_outputs_path(request)
|
||||||
sandbox = _resolve_sandbox(request)
|
return _patch_result(result, self._config, outputs_path)
|
||||||
return _patch_result(result, self._config, outputs_path, sandbox)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def awrap_tool_call(
|
async def awrap_tool_call(
|
||||||
@@ -605,12 +456,7 @@ class ToolOutputBudgetMiddleware(AgentMiddleware[AgentState]):
|
|||||||
if not _needs_budget(result, self._config):
|
if not _needs_budget(result, self._config):
|
||||||
return result
|
return result
|
||||||
outputs_path = _resolve_outputs_path(request)
|
outputs_path = _resolve_outputs_path(request)
|
||||||
# _resolve_sandbox only touches runtime.state and the provider's
|
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path)
|
||||||
# in-memory sandbox registry, so it is safe to call on the event
|
|
||||||
# loop. The actual sandbox I/O (mkdir/write/test) happens inside
|
|
||||||
# _patch_result, which is offloaded to a worker thread below.
|
|
||||||
sandbox = _resolve_sandbox(request)
|
|
||||||
return await asyncio.to_thread(_patch_result, result, self._config, outputs_path, sandbox)
|
|
||||||
|
|
||||||
# -- model call hooks (historical message truncation) ------------------
|
# -- model call hooks (historical message truncation) ------------------
|
||||||
|
|
||||||
|
|||||||
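Both versions of `awrap_tool_call` above hand the synchronous `_patch_result` call to `asyncio.to_thread`, so that file or sandbox I/O does not block the event loop. A minimal sketch of that offloading pattern is shown here. `slow_patch` and `patch_async` are hypothetical names, and the `time.sleep` call only stands in for blocking I/O.

```python
import asyncio
import threading
import time


def slow_patch(text: str, limit: int) -> tuple[str, str]:
    """Hypothetical blocking step; returns the truncated text and the thread it ran on."""
    time.sleep(0.05)  # stands in for blocking file or sandbox I/O
    return text[:limit], threading.current_thread().name


async def patch_async(text: str, limit: int) -> tuple[str, str]:
    # Positional arguments are forwarded to the callable in the worker thread.
    return await asyncio.to_thread(slow_patch, text, limit)


async def main() -> tuple[str, str, str]:
    loop_thread = threading.current_thread().name
    patched, worker_thread = await patch_async("x" * 100, 10)
    return patched, worker_thread, loop_thread


if __name__ == "__main__":
    print(asyncio.run(main()))
```

The blocking function runs on a worker thread from the default executor, while the coroutine awaiting it yields control, so other tasks on the loop keep making progress.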
@@ -13,7 +13,6 @@ from langgraph.runtime import Runtime
|
|||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.utils.file_conversion import extract_outline
|
from deerflow.utils.file_conversion import extract_outline
|
||||||
from deerflow.utils.messages import ORIGINAL_USER_CONTENT_KEY, message_content_to_text
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -266,8 +265,6 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
|
|
||||||
# Extract original content - handle both string and list formats
|
# Extract original content - handle both string and list formats
|
||||||
original_content = last_message.content
|
original_content = last_message.content
|
||||||
additional_kwargs = dict(last_message.additional_kwargs or {})
|
|
||||||
additional_kwargs.setdefault(ORIGINAL_USER_CONTENT_KEY, message_content_to_text(original_content))
|
|
||||||
if isinstance(original_content, str):
|
if isinstance(original_content, str):
|
||||||
# Simple case: string content, just prepend files message
|
# Simple case: string content, just prepend files message
|
||||||
updated_content = f"{files_message}\n\n{original_content}"
|
updated_content = f"{files_message}\n\n{original_content}"
|
||||||
@@ -288,7 +285,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
|||||||
content=updated_content,
|
content=updated_content,
|
||||||
id=last_message.id,
|
id=last_message.id,
|
||||||
name=last_message.name,
|
name=last_message.name,
|
||||||
additional_kwargs=additional_kwargs,
|
additional_kwargs=last_message.additional_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages[last_message_index] = updated_message
|
messages[last_message_index] = updated_message
|
||||||
|
|||||||
@@ -179,10 +179,8 @@ class ViewImageMiddleware(AgentMiddleware[ViewImageMiddlewareState]):
|
|||||||
# Create the image details message with text and image content
|
# Create the image details message with text and image content
|
||||||
image_content = self._create_image_details_message(state)
|
image_content = self._create_image_details_message(state)
|
||||||
|
|
||||||
# Create a new human message with mixed content (text + images). This is
|
# Create a new human message with mixed content (text + images)
|
||||||
# internal context for the model only, so hide it from the chat UI and IM
|
human_msg = HumanMessage(content=image_content)
|
||||||
# channels (matches the other middleware-injected context messages).
|
|
||||||
human_msg = HumanMessage(content=image_content, additional_kwargs={"hide_from_ui": True})
|
|
||||||
|
|
||||||
logger.debug("Injecting image details message with images before LLM call")
|
logger.debug("Injecting image details message with images before LLM call")
|
||||||
|
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from langchain.agents.middleware import AgentMiddleware
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
from deerflow.agents.lead_agent.agent import build_middlewares
|
from deerflow.agents.lead_agent.agent import _assemble_deferred, _build_middlewares
|
||||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.config.agents_config import AGENT_NAME_PATTERN
|
from deerflow.config.agents_config import AGENT_NAME_PATTERN
|
||||||
@@ -43,7 +43,6 @@ from deerflow.config.paths import get_paths
|
|||||||
from deerflow.models import create_chat_model
|
from deerflow.models import create_chat_model
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
from deerflow.skills.storage import get_or_new_skill_storage
|
from deerflow.skills.storage import get_or_new_skill_storage
|
||||||
from deerflow.tools.builtins.tool_search import assemble_deferred_tools
|
|
||||||
from deerflow.tracing import build_tracing_callbacks, inject_langfuse_metadata
|
from deerflow.tracing import build_tracing_callbacks, inject_langfuse_metadata
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
claim_unique_filename,
|
claim_unique_filename,
|
||||||
@@ -239,7 +238,7 @@ class DeerFlowClient:
|
|||||||
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
|
||||||
|
|
||||||
tools = self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
tools = self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled)
|
||||||
final_tools, deferred_setup = assemble_deferred_tools(tools, enabled=self._app_config.tool_search.enabled)
|
final_tools, deferred_setup = _assemble_deferred(tools, enabled=self._app_config.tool_search.enabled)
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
# attach_tracing=False because ``stream()`` injects tracing
|
# attach_tracing=False because ``stream()`` injects tracing
|
||||||
# callbacks at the graph invocation root so a single embedded run
|
# callbacks at the graph invocation root so a single embedded run
|
||||||
@@ -247,15 +246,7 @@ class DeerFlowClient:
|
|||||||
# Attaching them again on the model would emit duplicate spans.
|
# Attaching them again on the model would emit duplicate spans.
|
||||||
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False),
|
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False),
|
||||||
"tools": final_tools,
|
"tools": final_tools,
|
||||||
"middleware": build_middlewares(
|
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares, deferred_setup=deferred_setup),
|
||||||
config,
|
|
||||||
model_name=model_name,
|
|
||||||
agent_name=self._agent_name,
|
|
||||||
available_skills=self._available_skills,
|
|
||||||
custom_middlewares=self._middlewares,
|
|
||||||
app_config=self._app_config,
|
|
||||||
deferred_setup=deferred_setup,
|
|
||||||
),
|
|
||||||
"system_prompt": apply_prompt_template(
|
"system_prompt": apply_prompt_template(
|
||||||
subagent_enabled=subagent_enabled,
|
subagent_enabled=subagent_enabled,
|
||||||
max_concurrent_subagents=max_concurrent_subagents,
|
max_concurrent_subagents=max_concurrent_subagents,
|
||||||
@@ -1141,7 +1132,6 @@ class DeerFlowClient:
|
|||||||
"fact_confidence_threshold": config.fact_confidence_threshold,
|
"fact_confidence_threshold": config.fact_confidence_threshold,
|
||||||
"injection_enabled": config.injection_enabled,
|
"injection_enabled": config.injection_enabled,
|
||||||
"max_injection_tokens": config.max_injection_tokens,
|
"max_injection_tokens": config.max_injection_tokens,
|
||||||
"token_counting": config.token_counting,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_memory_status(self) -> dict:
|
def get_memory_status(self) -> dict:
|
||||||
|
|||||||
@@ -470,32 +470,14 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
|
|
||||||
existing_id = self._thread_sandboxes[thread_id]
|
existing_id = self._thread_sandboxes[thread_id]
|
||||||
if existing_id in self._sandboxes:
|
if existing_id in self._sandboxes:
|
||||||
info = self._sandbox_infos.get(existing_id)
|
|
||||||
else:
|
|
||||||
del self._thread_sandboxes[thread_id]
|
|
||||||
return None
|
|
||||||
|
|
||||||
alive = self._check_tracked_sandbox_alive(existing_id, info) if info is not None else True
|
|
||||||
if alive is False:
|
|
||||||
self._drop_unhealthy_sandbox(
|
|
||||||
existing_id,
|
|
||||||
"in-process cache failed health check",
|
|
||||||
expected_info=info,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
if self._thread_sandboxes.get(thread_id) != existing_id:
|
|
||||||
return None
|
|
||||||
if existing_id not in self._sandboxes:
|
|
||||||
self._thread_sandboxes.pop(thread_id, None)
|
|
||||||
return None
|
|
||||||
|
|
||||||
suffix = " (post-lock check)" if post_lock else ""
|
suffix = " (post-lock check)" if post_lock else ""
|
||||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
||||||
self._last_activity[existing_id] = time.time()
|
self._last_activity[existing_id] = time.time()
|
||||||
return existing_id
|
return existing_id
|
||||||
|
|
||||||
|
del self._thread_sandboxes[thread_id]
|
||||||
|
return None
|
||||||
|
|
||||||
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
|
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
|
||||||
"""Promote a warm-pool sandbox back to active tracking if available."""
|
"""Promote a warm-pool sandbox back to active tracking if available."""
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
@@ -505,22 +487,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
if sandbox_id not in self._warm_pool:
|
if sandbox_id not in self._warm_pool:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
info, _ = self._warm_pool[sandbox_id]
|
info, _ = self._warm_pool.pop(sandbox_id)
|
||||||
|
|
||||||
alive = self._check_tracked_sandbox_alive(sandbox_id, info)
|
|
||||||
if alive is False:
|
|
||||||
self._drop_unhealthy_sandbox(
|
|
||||||
sandbox_id,
|
|
||||||
"warm-pool cache failed health check",
|
|
||||||
expected_info=info,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
warm_item = self._warm_pool.pop(sandbox_id, None)
|
|
||||||
if warm_item is None:
|
|
||||||
return None
|
|
||||||
info, _ = warm_item
|
|
||||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||||
self._sandboxes[sandbox_id] = sandbox
|
self._sandboxes[sandbox_id] = sandbox
|
||||||
self._sandbox_infos[sandbox_id] = info
|
self._sandbox_infos[sandbox_id] = info
|
||||||
@@ -560,70 +527,6 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||||
return sandbox_id
|
return sandbox_id
|
||||||
|
|
||||||
def _check_tracked_sandbox_alive(self, sandbox_id: str, info: SandboxInfo) -> bool | None:
|
|
||||||
"""Return whether a tracked sandbox appears alive, or None if unknown."""
|
|
||||||
try:
|
|
||||||
return self._backend.is_alive(info)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to check sandbox {sandbox_id} health: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
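The `_check_tracked_sandbox_alive` helper on the `-` side of this hunk returns `bool | None`, and its callers only drop a sandbox when the result `is False`. A minimal sketch of that tri-state pattern follows. The names `check_alive`, `should_drop`, and `failing_probe` are hypothetical and do not appear in the diff; the probe callable stands in for `self._backend.is_alive(info)`.

```python
import logging
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def check_alive(sandbox_id: str, probe: Callable[[], bool]) -> Optional[bool]:
    """Return the probe's verdict, or None when the probe itself fails."""
    try:
        return probe()
    except Exception as exc:
        # A failed probe means health is unknown, not that the sandbox is dead.
        logger.warning("Failed to check sandbox %s health: %s", sandbox_id, exc)
        return None


def should_drop(alive: Optional[bool]) -> bool:
    """Only a definitive False justifies dropping; None keeps the sandbox."""
    return alive is False


def failing_probe() -> bool:
    """Hypothetical probe that simulates a transient daemon error."""
    raise RuntimeError("daemon unavailable")
```

Comparing with `is False` rather than relying on truthiness is what keeps an unknown result (`None`) from being treated as a dead sandbox.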
def _remove_tracked_sandbox(
|
|
||||||
self,
|
|
||||||
sandbox_id: str,
|
|
||||||
*,
|
|
||||||
expected_info: SandboxInfo | None = None,
|
|
||||||
) -> tuple[Sandbox | None, SandboxInfo | None, bool]:
|
|
||||||
"""Remove a sandbox from in-process tracking maps.
|
|
||||||
|
|
||||||
When expected_info is provided, removal only happens if the currently
|
|
||||||
tracked active or warm-pool entry is the exact info object that was
|
|
||||||
checked. This prevents a stale health-check result from deleting a
|
|
||||||
freshly recreated sandbox with the same deterministic id.
|
|
||||||
"""
|
|
||||||
thread_ids_to_remove: list[str] = []
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
active_info = self._sandbox_infos.get(sandbox_id)
|
|
||||||
warm_item = self._warm_pool.get(sandbox_id)
|
|
||||||
warm_info = warm_item[0] if warm_item is not None else None
|
|
||||||
if expected_info is not None and active_info is not expected_info and warm_info is not expected_info:
|
|
||||||
return None, None, False
|
|
||||||
|
|
||||||
sandbox = self._sandboxes.pop(sandbox_id, None)
|
|
||||||
info = self._sandbox_infos.pop(sandbox_id, None)
|
|
||||||
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
|
||||||
for tid in thread_ids_to_remove:
|
|
||||||
del self._thread_sandboxes[tid]
|
|
||||||
self._last_activity.pop(sandbox_id, None)
|
|
||||||
if info is None and sandbox_id in self._warm_pool:
|
|
||||||
info, _ = self._warm_pool.pop(sandbox_id)
|
|
||||||
else:
|
|
||||||
self._warm_pool.pop(sandbox_id, None)
|
|
||||||
|
|
||||||
return sandbox, info, True
|
|
||||||
|
|
||||||
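The docstring of `_remove_tracked_sandbox` describes an identity guard: when `expected_info` is given, removal only happens if the tracked entry is the exact object that was health-checked. The sketch below isolates that idea under simplifying assumptions. `Tracker` is a hypothetical stand-in with a single map, whereas the provider in the diff also consults the warm pool and cleans up thread mappings.

```python
import threading
from typing import Optional


class Tracker:
    """Hypothetical stand-in for the provider's in-process tracking maps."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._infos: dict[str, object] = {}

    def track(self, sandbox_id: str, info: object) -> None:
        with self._lock:
            self._infos[sandbox_id] = info

    def remove(self, sandbox_id: str, *, expected_info: Optional[object] = None) -> bool:
        """Remove the entry; with expected_info, only if it is that exact object."""
        with self._lock:
            current = self._infos.get(sandbox_id)
            if expected_info is not None and current is not expected_info:
                # The entry was replaced after the health check ran: leave it.
                return False
            self._infos.pop(sandbox_id, None)
            return True
```

The comparison uses `is not` (object identity) rather than `!=`, because a recreated sandbox with the same deterministic id could carry an info object that compares equal to the stale one.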
def _drop_unhealthy_sandbox(self, sandbox_id: str, reason: str, *, expected_info: SandboxInfo | None = None) -> None:
|
|
||||||
"""Remove and destroy a sandbox after a definitive failed health check."""
|
|
||||||
sandbox, info, removed = self._remove_tracked_sandbox(sandbox_id, expected_info=expected_info)
|
|
||||||
if not removed:
|
|
||||||
logger.info(f"Skipped dropping sandbox {sandbox_id}: tracked info changed after health check")
|
|
||||||
return
|
|
||||||
|
|
||||||
if sandbox is not None:
|
|
||||||
try:
|
|
||||||
sandbox.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing unhealthy sandbox {sandbox_id}: {e}")
|
|
||||||
|
|
||||||
if info is not None:
|
|
||||||
try:
|
|
||||||
self._backend.destroy(info)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error destroying unhealthy sandbox {sandbox_id}: {e}")
|
|
||||||
|
|
||||||
logger.warning(f"Dropped unhealthy sandbox {sandbox_id}: {reason}")
|
|
||||||
|
|
||||||
def _replica_count(self) -> tuple[int, int]:
|
def _replica_count(self) -> tuple[int, int]:
|
||||||
"""Return configured replicas and currently tracked sandbox count."""
|
"""Return configured replicas and currently tracked sandbox count."""
|
||||||
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
||||||
@@ -714,7 +617,7 @@ class AioSandboxProvider(SandboxProvider):
 
     async def _acquire_internal_async(self, thread_id: str | None) -> str:
         """Async counterpart to ``_acquire_internal``."""
-        cached_id = await asyncio.to_thread(self._reuse_in_process_sandbox, thread_id)
+        cached_id = self._reuse_in_process_sandbox(thread_id)
         if cached_id is not None:
             return cached_id
 
@@ -722,7 +625,7 @@ class AioSandboxProvider(SandboxProvider):
         sandbox_id = self._sandbox_id_for_thread(thread_id)
 
         # ── Layer 1.5: Warm pool (container still running, no cold-start) ──
-        reclaimed_id = await asyncio.to_thread(self._reclaim_warm_pool_sandbox, thread_id, sandbox_id)
+        reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
         if reclaimed_id is not None:
             return reclaimed_id
 
@@ -778,7 +681,7 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
locked = True
|
locked = True
|
||||||
# Re-check in-process caches under the file lock in case another
|
# Re-check in-process caches under the file lock in case another
|
||||||
# thread in this process won the race while we were waiting.
|
# thread in this process won the race while we were waiting.
|
||||||
cached_id = await asyncio.to_thread(self._recheck_cached_sandbox, thread_id, sandbox_id)
|
cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
|
||||||
if cached_id is not None:
|
if cached_id is not None:
|
||||||
return cached_id
|
return cached_id
|
||||||
|
|
||||||
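The three hunks above differ only in whether the cache lookups are wrapped in `asyncio.to_thread`. On the `-` side those lookups can reach a backend health check, which is blocking IO. The following is a minimal sketch of the wrapping pattern, not the provider's code. `blocking_lookup` and `acquire` are hypothetical names, and the `time.sleep` merely stands in for a subprocess or HTTP call.

```python
import asyncio
import threading
import time


def blocking_lookup(thread_id: str) -> tuple[str, bool]:
    """Hypothetical blocking call; also reports whether it ran off the main thread."""
    time.sleep(0.01)  # stands in for a subprocess or HTTP health check
    off_main = threading.current_thread() is not threading.main_thread()
    return f"sandbox-{thread_id}", off_main


async def acquire(thread_id: str) -> tuple[str, bool]:
    # asyncio.to_thread runs the callable in the default thread pool executor,
    # so the event loop stays free to schedule other tasks while it blocks.
    return await asyncio.to_thread(blocking_lookup, thread_id)
```

Calling `asyncio.run(acquire("t1"))` returns the sandbox id together with a flag showing the work ran on a worker thread. A direct call such as `blocking_lookup("t1")` inside a coroutine would instead stall the loop for the duration of the sleep.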
@@ -934,7 +837,22 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
Args:
|
Args:
|
||||||
sandbox_id: The ID of the sandbox to destroy.
|
sandbox_id: The ID of the sandbox to destroy.
|
||||||
"""
|
"""
|
||||||
sandbox, info, _ = self._remove_tracked_sandbox(sandbox_id)
|
info = None
|
||||||
|
sandbox = None
|
||||||
|
thread_ids_to_remove: list[str] = []
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
sandbox = self._sandboxes.pop(sandbox_id, None)
|
||||||
|
info = self._sandbox_infos.pop(sandbox_id, None)
|
||||||
|
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
|
||||||
|
for tid in thread_ids_to_remove:
|
||||||
|
del self._thread_sandboxes[tid]
|
||||||
|
self._last_activity.pop(sandbox_id, None)
|
||||||
|
# Also pull from warm pool if it was parked there
|
||||||
|
if info is None and sandbox_id in self._warm_pool:
|
||||||
|
info, _ = self._warm_pool.pop(sandbox_id)
|
||||||
|
else:
|
||||||
|
self._warm_pool.pop(sandbox_id, None)
|
||||||
|
|
||||||
if sandbox is not None:
|
if sandbox is not None:
|
||||||
# Defense-in-depth: close() already swallows its own errors; this
|
# Defense-in-depth: close() already swallows its own errors; this
|
||||||
|
|||||||
@@ -169,24 +169,6 @@ def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str |
|
|||||||
return "0.0.0.0"
|
return "0.0.0.0"
|
||||||
|
|
||||||
|
|
||||||
def _is_no_such_container_error(stderr: str, container_name: str) -> bool:
|
|
||||||
"""Return True only when stderr definitively says the container does not exist.
|
|
||||||
|
|
||||||
Docker reports "No such object" / "No such container". Apple Container
|
|
||||||
reports a generic "not found", so that phrase is only trusted when the
|
|
||||||
message also names the inspected container (or refers to a
|
|
||||||
container/object); transient failures whose text happens to contain
|
|
||||||
"not found" (e.g. "command not found", "context not found") must stay on
|
|
||||||
the raise path instead of being misread as a dead container.
|
|
||||||
"""
|
|
||||||
message = stderr.lower()
|
|
||||||
if "no such object" in message or "no such container" in message:
|
|
||||||
return True
|
|
||||||
if "not found" not in message:
|
|
||||||
return False
|
|
||||||
return container_name.lower() in message or "container" in message or "object" in message
|
|
||||||
|
|
||||||
|
|
||||||
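The matching rules in `_is_no_such_container_error` are easier to follow with concrete inputs. The sketch below restates the function body from the `-` side of the hunk under a public name (`is_no_such_container_error`, chosen here for illustration). The sample stderr strings in the usage note are illustrative and are not captured from a real Docker or Apple Container run.

```python
def is_no_such_container_error(stderr: str, container_name: str) -> bool:
    """Return True only when stderr definitively says the container is absent."""
    message = stderr.lower()
    # Docker-style phrasing is unambiguous.
    if "no such object" in message or "no such container" in message:
        return True
    # Without a "not found" phrase there is nothing to interpret.
    if "not found" not in message:
        return False
    # A generic "not found" is trusted only when it names the container
    # or refers to a container/object.
    return (
        container_name.lower() in message
        or "container" in message
        or "object" in message
    )
```

With these rules, `"Error: No such object: sb-1"` and `"container sb-1 not found"` both count as definitive, while `"bash: docker: command not found"` does not, so a missing CLI is not misread as a dead container.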
class LocalContainerBackend(SandboxBackend):
|
class LocalContainerBackend(SandboxBackend):
|
||||||
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
||||||
|
|
||||||
@@ -353,21 +335,11 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
sandbox_id: The deterministic sandbox ID (determines container name).
|
sandbox_id: The deterministic sandbox ID (determines container name).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SandboxInfo if container found and healthy, None otherwise. A
|
SandboxInfo if container found and healthy, None otherwise.
|
||||||
failed runtime check (e.g. transient daemon error) also returns
|
|
||||||
None — discovery must not adopt a container it cannot verify, and
|
|
||||||
falling through to create keeps acquire recoverable instead of
|
|
||||||
hard-failing on a hiccup.
|
|
||||||
"""
|
"""
|
||||||
container_name = f"{self._container_prefix}-{sandbox_id}"
|
container_name = f"{self._container_prefix}-{sandbox_id}"
|
||||||
|
|
||||||
try:
|
if not self._is_container_running(container_name):
|
||||||
running = self._is_container_running(container_name)
|
|
||||||
except RuntimeError as e:
|
|
||||||
logger.warning(f"Could not verify container {container_name} during discovery; not adopting it: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not running:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
port = self._get_container_port(container_name)
|
port = self._get_container_port(container_name)
|
||||||
@@ -610,13 +582,6 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
|
|
||||||
This enables cross-process container discovery — any process can detect
|
This enables cross-process container discovery — any process can detect
|
||||||
containers started by another process via the deterministic container name.
|
containers started by another process via the deterministic container name.
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If the container runtime cannot answer the inspect
|
|
||||||
query. A failed check is intentionally distinct from a
|
|
||||||
definitive "container does not exist" result so callers do not
|
|
||||||
destroy healthy containers during transient Docker/Container
|
|
||||||
daemon failures.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
@@ -625,14 +590,9 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
text=True,
|
text=True,
|
||||||
timeout=5,
|
timeout=5,
|
||||||
)
|
)
|
||||||
except subprocess.TimeoutExpired as exc:
|
return result.returncode == 0 and result.stdout.strip().lower() == "true"
|
||||||
raise RuntimeError(f"Timed out checking container {container_name}") from exc
|
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
|
||||||
|
|
||||||
if result.returncode == 0:
|
|
||||||
return result.stdout.strip().lower() == "true"
|
|
||||||
if _is_no_such_container_error(result.stderr, container_name):
|
|
||||||
return False
|
return False
|
||||||
raise RuntimeError(f"Failed to inspect container {container_name}: {result.stderr.strip()}")
|
|
||||||
|
|
||||||
def _get_container_port(self, container_name: str) -> int | None:
|
def _get_container_port(self, container_name: str) -> int | None:
|
||||||
"""Get the host port of a running container.
|
"""Get the host port of a running container.
|
||||||
|
|||||||
@@ -176,16 +176,12 @@ class RemoteSandboxBackend(SandboxBackend):
|
|||||||
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
|
||||||
timeout=10,
|
timeout=10,
|
||||||
)
|
)
|
||||||
except requests.RequestException as exc:
|
if resp.ok:
|
||||||
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: {exc}") from exc
|
|
||||||
|
|
||||||
if resp.status_code == 404:
|
|
||||||
return False
|
|
||||||
if not resp.ok:
|
|
||||||
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: HTTP {resp.status_code} {resp.text}")
|
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
return data.get("status") == "Running"
|
return data.get("status") == "Running"
|
||||||
|
return False
|
||||||
|
except requests.RequestException:
|
||||||
|
return False
|
||||||
|
|
||||||
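The two sides of the `RemoteSandboxBackend` hunk interpret the provisioner response differently. The `-` side treats HTTP 404 as a definitive "not alive" and raises on any other failure, while the `+` side collapses every failure into `False`. The sketch below models the `-` side as a pure function so the branches can be exercised without a network. `interpret_health` is a hypothetical name, and the `>= 400` test mirrors how `requests` defines `Response.ok`.

```python
def interpret_health(sandbox_id: str, status_code: int, payload: dict) -> bool:
    """Map a provisioner response to alive / not alive, raising when unknown."""
    if status_code == 404:
        # The provisioner answered and the sandbox does not exist.
        return False
    if status_code >= 400:
        # Any other error says nothing about the sandbox itself.
        raise RuntimeError(
            f"Provisioner health check failed for {sandbox_id}: HTTP {status_code}"
        )
    return payload.get("status") == "Running"
```

Keeping "unknown" on the raise path lets callers distinguish a transient provisioner outage from a sandbox that is actually gone.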
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
|
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
|
||||||
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
|
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
-from .browserless_client import BrowserlessClient
-from .tools import web_fetch_tool
-
-__all__ = ["BrowserlessClient", "web_fetch_tool"]
@@ -1,98 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BrowserlessClient:
|
|
||||||
"""Client for Browserless headless Chrome API."""
|
|
||||||
|
|
||||||
def __init__(self, base_url: str, token: str = "", timeout_s: float = 30) -> None:
|
|
||||||
self.base_url = base_url.rstrip("/")
|
|
||||||
self.token = token
|
|
||||||
self.timeout_s = timeout_s
|
|
||||||
|
|
||||||
async def fetch_html(
|
|
||||||
self,
|
|
||||||
url: str,
|
|
||||||
wait_for_event: str = "",
|
|
||||||
wait_for_timeout_ms: int = 0,
|
|
||||||
wait_for_selector: str = "",
|
|
||||||
wait_for_selector_timeout_ms: int = 5000,
|
|
||||||
reject_resource_types: list[str] | None = None,
|
|
||||||
reject_request_pattern: list[str] | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Fetch the rendered HTML of a page using Browserless.
|
|
||||||
|
|
||||||
Only sends accepted parameters for the current Browserless API version.
|
|
||||||
Sets a default navigation timeout (30s) via query param.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: The URL to fetch.
|
|
||||||
wait_for_event: Wait for a page event (e.g. "networkidle", "load").
|
|
||||||
wait_for_timeout_ms: Extra wait after page load.
|
|
||||||
wait_for_selector: CSS selector to wait for.
|
|
||||||
wait_for_selector_timeout_ms: Timeout for selector wait.
|
|
||||||
reject_resource_types: Resource types to block (e.g. ["image"]).
|
|
||||||
reject_request_pattern: URL patterns to block.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rendered HTML content.
|
|
||||||
"""
|
|
||||||
payload: dict[str, Any] = {
|
|
||||||
"url": url,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.token:
|
|
||||||
payload["token"] = self.token
|
|
||||||
if wait_for_event:
|
|
||||||
payload["waitForEvent"] = wait_for_event
|
|
||||||
if wait_for_timeout_ms > 0:
|
|
||||||
payload["waitForTimeout"] = wait_for_timeout_ms
|
|
||||||
if wait_for_selector:
|
|
||||||
payload["waitForSelector"] = {
|
|
||||||
"selector": wait_for_selector,
|
|
||||||
"timeout": wait_for_selector_timeout_ms,
|
|
||||||
}
|
|
||||||
if reject_resource_types:
|
|
||||||
payload["rejectResourceTypes"] = reject_resource_types
|
|
||||||
if reject_request_pattern:
|
|
||||||
payload["rejectRequestPattern"] = reject_request_pattern
|
|
||||||
|
|
||||||
logger.debug(f"Fetching URL via Browserless: {url}")
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=self.timeout_s) as client:
|
|
||||||
resp = await client.post(
|
|
||||||
f"{self.base_url}/content",
|
|
||||||
json=payload,
|
|
||||||
headers={
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
code = resp.status_code
|
|
||||||
target_code = resp.headers.get("X-Response-Code", "")
|
|
||||||
target_status = resp.headers.get("X-Response-Status", "")
|
|
||||||
|
|
||||||
logger.debug(f"Browserless response: code={code}, target_code={target_code}, target_status={target_status}")
|
|
||||||
|
|
||||||
if code != 200:
|
|
||||||
return f"Error: Browserless HTTP {code}: {resp.text[:200]}"
|
|
||||||
|
|
||||||
html = resp.text
|
|
||||||
if not html or not html.strip():
|
|
||||||
return "Error: Browserless returned empty response"
|
|
||||||
|
|
||||||
return html
|
|
||||||
|
|
||||||
except httpx.TimeoutException:
|
|
||||||
return f"Error: Browserless request timed out after {self.timeout_s}s"
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error(f"Browserless request failed: {e}")
|
|
||||||
return f"Error: Browserless request failed: {e!s}"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Browserless fetch failed: {e}")
|
|
||||||
return f"Error: Browserless fetch failed: {e!s}"
|
|
||||||
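`BrowserlessClient.fetch_html` only adds a key to the request body when the corresponding argument is set. A sketch of that payload construction as a standalone function follows. `build_payload` is a hypothetical helper that does not exist in the removed file; the key names are copied from the hunk above.

```python
from typing import Any, Optional


def build_payload(
    url: str,
    token: str = "",
    wait_for_event: str = "",
    wait_for_timeout_ms: int = 0,
    wait_for_selector: str = "",
    wait_for_selector_timeout_ms: int = 5000,
    reject_resource_types: Optional[list[str]] = None,
    reject_request_pattern: Optional[list[str]] = None,
) -> dict[str, Any]:
    """Build the JSON body, omitting every option left at its empty default."""
    payload: dict[str, Any] = {"url": url}
    if token:
        payload["token"] = token
    if wait_for_event:
        payload["waitForEvent"] = wait_for_event
    if wait_for_timeout_ms > 0:
        payload["waitForTimeout"] = wait_for_timeout_ms
    if wait_for_selector:
        # The selector timeout is only meaningful alongside a selector.
        payload["waitForSelector"] = {
            "selector": wait_for_selector,
            "timeout": wait_for_selector_timeout_ms,
        }
    if reject_resource_types:
        payload["rejectResourceTypes"] = reject_resource_types
    if reject_request_pattern:
        payload["rejectRequestPattern"] = reject_request_pattern
    return payload
```

Separating payload construction from the HTTP call would make the option handling testable without a running Browserless instance.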
@@ -1,85 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
|
||||||
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
from deerflow.utils.readability import ReadabilityExtractor
|
|
||||||
|
|
||||||
from .browserless_client import BrowserlessClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# readability_extractor runs CPU-bound parsing; always call via asyncio.to_thread
|
|
||||||
_readability_extractor = ReadabilityExtractor()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_config(tool_name: str) -> dict | None:
|
|
||||||
"""Get tool config extras safely, returning None if not configured."""
|
|
||||||
config = get_app_config().get_tool_config(tool_name)
|
|
||||||
if config is None:
|
|
||||||
return None
|
|
||||||
extras = config.model_extra
|
|
||||||
return extras if extras is not None else {}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_browserless_client() -> BrowserlessClient:
|
|
||||||
cfg = _get_tool_config("web_fetch")
|
|
||||||
base_url = "http://localhost:3032"
|
|
||||||
token = ""
|
|
||||||
timeout_s = 30.0
|
|
||||||
if cfg is not None:
|
|
||||||
base_url = cfg.get("base_url", base_url)
|
|
||||||
token = cfg.get("token", token)
|
|
||||||
raw = cfg.get("timeout_s", timeout_s)
|
|
||||||
timeout_s = float(raw) if not isinstance(raw, float) else raw
|
|
||||||
return BrowserlessClient(base_url=base_url, token=token, timeout_s=timeout_s)
|
|
||||||
|
|
||||||
|
|
||||||
@tool("web_fetch", parse_docstring=True)
|
|
||||||
async def web_fetch_tool(url: str) -> str:
|
|
||||||
"""Fetch the contents of a web page at a given URL using Browserless (headless Chrome).
|
|
||||||
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
|
|
||||||
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
|
|
||||||
Do NOT add www. to URLs that do NOT have them.
|
|
||||||
URLs must include the schema: https://example.com is a valid URL while example.com is an invalid URL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: The URL to fetch the contents of.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
cfg = _get_tool_config("web_fetch")
|
|
||||||
|
|
||||||
wait_for_event = ""
|
|
||||||
wait_for_timeout_ms = 0
|
|
||||||
wait_for_selector = ""
|
|
||||||
wait_for_selector_timeout_ms = 5000
|
|
||||||
reject_resource_types: list[str] | None = None
|
|
||||||
reject_request_pattern: list[str] | None = None
|
|
||||||
|
|
||||||
if cfg is not None:
|
|
||||||
wait_for_event = cfg.get("wait_for_event", wait_for_event)
|
|
||||||
raw_wait = cfg.get("wait_for_timeout_ms", wait_for_timeout_ms)
|
|
||||||
wait_for_timeout_ms = int(raw_wait) if not isinstance(raw_wait, int) else raw_wait
|
|
||||||
wait_for_selector = cfg.get("wait_for_selector", wait_for_selector)
|
|
||||||
|
|
||||||
client = _get_browserless_client()
|
|
||||||
html = await client.fetch_html(
|
|
||||||
url=url,
|
|
||||||
wait_for_event=wait_for_event,
|
|
||||||
wait_for_timeout_ms=wait_for_timeout_ms,
|
|
||||||
wait_for_selector=wait_for_selector,
|
|
||||||
wait_for_selector_timeout_ms=wait_for_selector_timeout_ms,
|
|
||||||
reject_resource_types=reject_resource_types,
|
|
||||||
reject_request_pattern=reject_request_pattern,
|
|
||||||
)
|
|
||||||
|
|
||||||
if html.startswith("Error:"):
|
|
||||||
return html
|
|
||||||
|
|
||||||
article = await asyncio.to_thread(_readability_extractor.extract_article, html)
|
|
||||||
return article.to_markdown()[:4096]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in web_fetch_tool: {e}")
|
|
||||||
return f"Error: {str(e)}"
|
|
||||||
@@ -11,85 +11,12 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_BACKEND = "auto"
|
|
||||||
DEFAULT_REGION = "wt-wt"
|
|
||||||
DEFAULT_SAFESEARCH = "moderate"
|
|
||||||
DEFAULT_WIKIPEDIA_REGION = "us-en"
|
|
||||||
|
|
||||||
WIKIPEDIA_BACKENDS = {"auto", "all", "wikipedia"}
|
|
||||||
WIKIPEDIA_LANGUAGE_ALIASES = {
|
|
||||||
"jp": "ja",
|
|
||||||
"kr": "ko",
|
|
||||||
"tzh": "zh",
|
|
||||||
"wt": "en",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_backend(backend: str | list[str] | tuple[str, ...] | None) -> str:
|
|
||||||
if backend is None:
|
|
||||||
return DEFAULT_BACKEND
|
|
||||||
if isinstance(backend, (list, tuple)):
|
|
||||||
return ",".join(str(part).strip() for part in backend if str(part).strip()) or DEFAULT_BACKEND
|
|
||||||
return str(backend).strip() or DEFAULT_BACKEND
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_setting(value: str | None, default: str) -> str:
|
|
||||||
return str(value).strip() if value else default
|
|
||||||
|
|
||||||
|
|
||||||
def _backend_includes_wikipedia(backend: str | list[str] | tuple[str, ...] | None) -> bool:
|
|
||||||
backend = _normalize_backend(backend)
|
|
||||||
return any(part.strip().lower() in WIKIPEDIA_BACKENDS for part in backend.split(","))
|
|
||||||
|
|
||||||
|
|
||||||
def _contains_codepoint(query: str, ranges: tuple[tuple[int, int], ...]) -> bool:
|
|
||||||
return any(start <= ord(char) <= end for char in query for start, end in ranges)
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_wikipedia_region(query: str) -> str:
|
|
||||||
"""Pick a valid Wikipedia language region when DDGS' worldwide region is used."""
|
|
||||||
if _contains_codepoint(query, ((0x3040, 0x30FF), (0x31F0, 0x31FF))):
|
|
||||||
return "jp-ja"
|
|
||||||
if _contains_codepoint(query, ((0xAC00, 0xD7AF), (0x1100, 0x11FF), (0x3130, 0x318F))):
|
|
||||||
return "kr-ko"
|
|
||||||
if _contains_codepoint(query, ((0x3400, 0x9FFF),)):
|
|
||||||
return "cn-zh"
|
|
||||||
if _contains_codepoint(query, ((0x0400, 0x04FF),)):
|
|
||||||
return "ru-ru"
|
|
||||||
if _contains_codepoint(query, ((0x0370, 0x03FF),)):
|
|
||||||
return "gr-el"
|
|
||||||
if _contains_codepoint(query, ((0x0590, 0x05FF),)):
|
|
||||||
return "il-he"
|
|
||||||
if _contains_codepoint(query, ((0x0600, 0x06FF),)):
|
|
||||||
return "xa-ar"
|
|
||||||
return DEFAULT_WIKIPEDIA_REGION
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_ddgs_region(query: str, region: str | None, backend: str | list[str] | tuple[str, ...] | None) -> str:
|
|
||||||
"""
|
|
||||||
DDGS' wikipedia engine treats the second part of region as a Wikipedia
|
|
||||||
subdomain. Its default worldwide region, wt-wt, becomes wt.wikipedia.org.
|
|
||||||
"""
|
|
||||||
normalized_region = _normalize_setting(region, DEFAULT_REGION).lower()
|
|
||||||
if not _backend_includes_wikipedia(backend):
|
|
||||||
return normalized_region
|
|
||||||
|
|
||||||
if normalized_region == DEFAULT_REGION:
|
|
||||||
return _infer_wikipedia_region(query)
|
|
||||||
|
|
||||||
if "-" not in normalized_region:
|
|
||||||
return DEFAULT_WIKIPEDIA_REGION
|
|
||||||
|
|
||||||
country, language = normalized_region.split("-", 1)
|
|
||||||
return f"{country}-{WIKIPEDIA_LANGUAGE_ALIASES.get(language, language)}"
|
|
||||||
|
|
||||||
|
|
||||||
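The region helpers on the `-` side exist because, per the docstring, the DDGS wikipedia engine turns the second part of the region into a Wikipedia subdomain, so `wt-wt` would become `wt.wikipedia.org`. The sketch below condenses that resolution logic into two functions. It assumes `backend` is a comma-separated string (the original also accepts lists and tuples), and the table name `SCRIPT_REGIONS` is introduced here for illustration.

```python
from typing import Optional

DEFAULT_REGION = "wt-wt"
DEFAULT_WIKIPEDIA_REGION = "us-en"
WIKIPEDIA_BACKENDS = {"auto", "all", "wikipedia"}
WIKIPEDIA_LANGUAGE_ALIASES = {"jp": "ja", "kr": "ko", "tzh": "zh", "wt": "en"}

# Checked in order: kana comes before CJK ideographs so that Japanese text
# containing kanji still resolves to the Japanese region.
SCRIPT_REGIONS = [
    ("jp-ja", ((0x3040, 0x30FF), (0x31F0, 0x31FF))),
    ("kr-ko", ((0xAC00, 0xD7AF), (0x1100, 0x11FF), (0x3130, 0x318F))),
    ("cn-zh", ((0x3400, 0x9FFF),)),
    ("ru-ru", ((0x0400, 0x04FF),)),
    ("gr-el", ((0x0370, 0x03FF),)),
    ("il-he", ((0x0590, 0x05FF),)),
    ("xa-ar", ((0x0600, 0x06FF),)),
]


def infer_wikipedia_region(query: str) -> str:
    """Pick a region from the first script range that matches any character."""
    for region, ranges in SCRIPT_REGIONS:
        if any(lo <= ord(ch) <= hi for ch in query for lo, hi in ranges):
            return region
    return DEFAULT_WIKIPEDIA_REGION


def resolve_region(query: str, region: Optional[str], backend: Optional[str]) -> str:
    """Return a region whose language part is a usable Wikipedia subdomain."""
    normalized = (region.strip() if region else DEFAULT_REGION).lower()
    parts = (backend or "auto").split(",")
    if not any(part.strip().lower() in WIKIPEDIA_BACKENDS for part in parts):
        return normalized  # no Wikipedia engine involved, pass through
    if normalized == DEFAULT_REGION:
        return infer_wikipedia_region(query)
    if "-" not in normalized:
        return DEFAULT_WIKIPEDIA_REGION
    country, language = normalized.split("-", 1)
    return f"{country}-{WIKIPEDIA_LANGUAGE_ALIASES.get(language, language)}"
```

For a backend that does not include Wikipedia, such as `"duckduckgo"`, the region passes through unchanged. The rewrite only applies when one of `auto`, `all`, or `wikipedia` is selected.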
def _search_text(
|
def _search_text(
|
||||||
query: str,
|
query: str,
|
||||||
max_results: int = 5,
|
max_results: int = 5,
|
||||||
region: str | None = DEFAULT_REGION,
|
region: str = "wt-wt",
|
||||||
safesearch: str | None = DEFAULT_SAFESEARCH,
|
safesearch: str = "moderate",
|
||||||
backend: str | list[str] | tuple[str, ...] | None = DEFAULT_BACKEND,
|
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Execute text search using DuckDuckGo.
|
Execute text search using DuckDuckGo.
|
||||||
@@ -99,7 +26,6 @@ def _search_text(
|
|||||||
max_results: Maximum number of results
|
max_results: Maximum number of results
|
||||||
region: Search region
|
region: Search region
|
||||||
safesearch: Safe search level
|
safesearch: Safe search level
|
||||||
backend: DDGS backend(s), e.g. "auto", "duckduckgo", or "duckduckgo,brave"
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of search results
|
List of search results
|
||||||
@@ -113,15 +39,11 @@ def _search_text(
|
|||||||
ddgs = DDGS(timeout=30)
|
ddgs = DDGS(timeout=30)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
backend = _normalize_backend(backend)
|
|
||||||
safesearch = _normalize_setting(safesearch, DEFAULT_SAFESEARCH)
|
|
||||||
effective_region = _resolve_ddgs_region(query, region, backend)
|
|
||||||
results = ddgs.text(
|
results = ddgs.text(
|
||||||
query,
|
query,
|
||||||
region=effective_region,
|
region=region,
|
||||||
safesearch=safesearch,
|
safesearch=safesearch,
|
||||||
max_results=max_results,
|
max_results=max_results,
|
||||||
backend=backend,
|
|
||||||
)
|
)
|
||||||
return list(results) if results else []
|
return list(results) if results else []
|
||||||
|
|
||||||
@@ -142,23 +64,14 @@ def web_search_tool(
|
|||||||
max_results: Maximum number of results to return. Default is 5.
|
max_results: Maximum number of results to return. Default is 5.
|
||||||
"""
|
"""
|
||||||
config = get_app_config().get_tool_config("web_search")
|
config = get_app_config().get_tool_config("web_search")
|
||||||
region = DEFAULT_REGION
|
|
||||||
safesearch = DEFAULT_SAFESEARCH
|
|
||||||
backend = DEFAULT_BACKEND
|
|
||||||
|
|
||||||
if config is not None:
|
# Override max_results from config if set
|
||||||
# Override tool call defaults from config if set.
|
if config is not None and "max_results" in config.model_extra:
|
||||||
max_results = config.model_extra.get("max_results", max_results)
|
max_results = config.model_extra.get("max_results", max_results)
|
||||||
region = config.model_extra.get("region", region)
|
|
||||||
safesearch = config.model_extra.get("safesearch", safesearch)
|
|
||||||
backend = config.model_extra.get("backend", backend)
|
|
||||||
|
|
||||||
results = _search_text(
|
results = _search_text(
|
||||||
query=query,
|
query=query,
|
||||||
max_results=max_results,
|
max_results=max_results,
|
||||||
region=region,
|
|
||||||
safesearch=safesearch,
|
|
||||||
backend=backend,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ _api_key_warned = False
|
|||||||
|
|
||||||
|
|
||||||
class JinaClient:
|
class JinaClient:
|
||||||
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10, proxy: str | None = None, trust_env: bool = True) -> str:
|
async def crawl(self, url: str, return_format: str = "html", timeout: int = 10) -> str:
|
||||||
global _api_key_warned
|
global _api_key_warned
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
@@ -23,10 +23,7 @@ class JinaClient:
|
|||||||
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
logger.warning("Jina API key is not set. Provide your own key to access a higher rate limit. See https://jina.ai/reader for more information.")
|
||||||
data = {"url": url}
|
data = {"url": url}
|
||||||
try:
|
try:
|
||||||
client_kwargs: dict[str, object] = {"trust_env": trust_env}
|
async with httpx.AsyncClient() as client:
|
||||||
if proxy:
|
|
||||||
client_kwargs["proxy"] = proxy
|
|
||||||
async with httpx.AsyncClient(**client_kwargs) as client:
|
|
||||||
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
response = await client.post("https://r.jina.ai/", headers=headers, json=data, timeout=timeout)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
|||||||
@@ -9,38 +9,6 @@ from deerflow.utils.readability import ReadabilityExtractor
|
|||||||
readability_extractor = ReadabilityExtractor()
|
readability_extractor = ReadabilityExtractor()
|
||||||
|
|
||||||
|
|
||||||
def _coerce_bool(value: object, default: bool) -> bool:
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return value
|
|
||||||
if isinstance(value, str):
|
|
||||||
normalized = value.strip().lower()
|
|
||||||
if normalized in {"1", "true", "yes", "on"}:
|
|
||||||
return True
|
|
||||||
if normalized in {"0", "false", "no", "off"}:
|
|
||||||
return False
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def _coerce_timeout(value: object, default: int) -> int:
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return default
|
|
||||||
if isinstance(value, int):
|
|
||||||
return value
|
|
||||||
if isinstance(value, str):
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError:
|
|
||||||
return default
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def _coerce_proxy(value: object) -> str | None:
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return None
|
|
||||||
proxy = value.strip()
|
|
||||||
return proxy or None
|
|
||||||
|
|
||||||
|
|
||||||
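The three `_coerce_*` helpers on the `-` side defend against loosely typed config values, for example a YAML string where an int or bool is expected. The sketch below restates them under public names so the edge cases can be checked in isolation. The names `coerce_bool`, `coerce_timeout`, and `coerce_proxy` are chosen here and are not the module's API.

```python
from typing import Optional


def coerce_bool(value: object, default: bool) -> bool:
    """Accept real bools and common string spellings; otherwise use the default."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"1", "true", "yes", "on"}:
            return True
        if normalized in {"0", "false", "no", "off"}:
            return False
    return default


def coerce_timeout(value: object, default: int) -> int:
    """Accept ints and integer strings; reject bools, which are also ints."""
    if isinstance(value, bool):
        return default
    if isinstance(value, int):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return default
    return default


def coerce_proxy(value: object) -> Optional[str]:
    """Return a stripped proxy URL, or None for non-strings and blank strings."""
    if not isinstance(value, str):
        return None
    proxy = value.strip()
    return proxy or None
```

The `bool` check comes first in `coerce_timeout` because `isinstance(True, int)` is true in Python, so without it `timeout: true` would silently become a one-second timeout.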
@tool("web_fetch", parse_docstring=True)
|
@tool("web_fetch", parse_docstring=True)
|
||||||
async def web_fetch_tool(url: str) -> str:
|
async def web_fetch_tool(url: str) -> str:
|
||||||
"""Fetch the contents of a web page at a given URL.
|
"""Fetch the contents of a web page at a given URL.
|
||||||
@@ -54,14 +22,10 @@ async def web_fetch_tool(url: str) -> str:
|
|||||||
"""
|
"""
|
||||||
jina_client = JinaClient()
|
jina_client = JinaClient()
|
||||||
timeout = 10
|
timeout = 10
|
||||||
proxy = None
|
|
||||||
trust_env = True
|
|
||||||
config = get_app_config().get_tool_config("web_fetch")
|
config = get_app_config().get_tool_config("web_fetch")
|
||||||
if config is not None:
|
if config is not None and "timeout" in config.model_extra:
|
||||||
timeout = _coerce_timeout(config.model_extra.get("timeout"), timeout)
|
timeout = config.model_extra.get("timeout")
|
||||||
proxy = _coerce_proxy(config.model_extra.get("proxy"))
|
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
|
||||||
trust_env = _coerce_bool(config.model_extra.get("trust_env"), trust_env)
|
|
||||||
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout, proxy=proxy, trust_env=trust_env)
|
|
||||||
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
if isinstance(html_content, str) and html_content.startswith("Error:"):
|
||||||
return html_content
|
return html_content
|
||||||
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
|
article = await asyncio.to_thread(readability_extractor.extract_article, html_content)
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
from .tools import web_search_tool
|
|
||||||
|
|
||||||
__all__ = ["web_search_tool"]
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SearxngClient:
|
|
||||||
"""Client for SearXNG meta search engine API."""
|
|
||||||
|
|
||||||
def __init__(self, base_url: str) -> None:
|
|
||||||
self.base_url = base_url.rstrip("/")
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
max_results: int = 5,
|
|
||||||
categories: list[str] | None = None,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Search the web using SearXNG.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The search query.
|
|
||||||
max_results: Maximum number of results to return.
|
|
||||||
categories: Search categories to use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of search result dictionaries.
|
|
||||||
"""
|
|
||||||
params: dict[str, Any] = {
|
|
||||||
"q": query,
|
|
||||||
"format": "json",
|
|
||||||
"language": "auto",
|
|
||||||
"pageno": 1,
|
|
||||||
}
|
|
||||||
if max_results:
|
|
||||||
params["limit"] = max_results
|
|
||||||
if categories:
|
|
||||||
params["categories"] = ",".join(categories)
|
|
||||||
|
|
||||||
logger.debug(f"Searching SearXNG at {self.base_url} with query: {query}")
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=30) as client:
|
|
||||||
resp = await client.get(
|
|
||||||
f"{self.base_url}/search",
|
|
||||||
params=params,
|
|
||||||
headers={
|
|
||||||
"User-Agent": "Mozilla/5.0 (compatible; DeerFlow/1.0)",
|
|
||||||
"Accept": "application/json",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = resp.json()
|
|
||||||
results = data.get("results", [])
|
|
||||||
return results[:max_results] if max_results else results
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
logger.error(f"SearXNG search returned error status: {e}")
|
|
||||||
raise
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error(f"SearXNG search request failed: {e}")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"An unexpected error occurred during SearXNG search: {e}")
|
|
||||||
raise
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
|
||||||
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
from .searxng_client import SearxngClient
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_config(tool_name: str) -> dict | None:
|
|
||||||
"""Get tool config extras safely, returning None if not configured."""
|
|
||||||
config = get_app_config().get_tool_config(tool_name)
|
|
||||||
if config is None:
|
|
||||||
return None
|
|
||||||
extras = config.model_extra
|
|
||||||
return extras if extras is not None else {}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_searxng_client() -> SearxngClient:
|
|
||||||
cfg = _get_tool_config("web_search")
|
|
||||||
base_url = "http://localhost:8088"
|
|
||||||
if cfg is not None:
|
|
||||||
base_url = cfg.get("base_url", base_url)
|
|
||||||
return SearxngClient(base_url=base_url)
|
|
||||||
|
|
||||||
|
|
||||||
@tool("web_search", parse_docstring=True)
|
|
||||||
async def web_search_tool(query: str) -> str:
|
|
||||||
"""Search the web using SearXNG.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: The query to search for.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
cfg = _get_tool_config("web_search")
|
|
||||||
max_results = 5
|
|
||||||
if cfg is not None:
|
|
||||||
raw = cfg.get("max_results", max_results)
|
|
||||||
max_results = int(raw) if not isinstance(raw, int) else raw
|
|
||||||
|
|
||||||
client = _get_searxng_client()
|
|
||||||
results = await client.search(query, max_results=max_results)
|
|
||||||
|
|
||||||
normalized = [
|
|
||||||
{
|
|
||||||
"title": r.get("title", ""),
|
|
||||||
"url": r.get("url", ""),
|
|
||||||
"snippet": r.get("content", ""),
|
|
||||||
}
|
|
||||||
for r in results
|
|
||||||
]
|
|
||||||
return json.dumps(normalized, indent=2, ensure_ascii=False)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in web_search_tool: {e}")
|
|
||||||
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
|
||||||
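The tool above reduces each raw search result to a `title` / `url` / `snippet` triple before serializing it. The sketch below isolates that step so it can be tried without a running SearXNG instance. The function name `normalize_results` and the sample data are hypothetical; real responses carry more fields than shown here.

```python
import json
from typing import Any


def normalize_results(results: list[dict[str, Any]]) -> str:
    """Map raw result dicts to the title/url/snippet shape and serialize them."""
    normalized = [
        {
            "title": r.get("title", ""),
            "url": r.get("url", ""),
            # The raw field is called "content"; the tool exposes it as "snippet".
            "snippet": r.get("content", ""),
        }
        for r in results
    ]
    return json.dumps(normalized, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    # Hypothetical sample data, not captured from a real server.
    sample = [
        {"title": "DeerFlow", "url": "https://example.com", "content": "A demo.", "engine": "x"},
        {},  # missing keys become empty strings
    ]
    print(normalize_results(sample))
```

Using `.get(..., "")` means a result with missing keys still produces a well-formed entry instead of raising `KeyError`.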
@@ -67,13 +67,11 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
|||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
effective_user = user_id or get_effective_user_id()
|
effective_user = user_id or get_effective_user_id()
|
||||||
user_path = paths.user_agent_dir(effective_user, name)
|
user_path = paths.user_agent_dir(effective_user, name)
|
||||||
# Require config.yaml to confirm this is a genuine agent directory,
|
if user_path.exists():
|
||||||
# not a leftover from memory/storage writes (see #3390).
|
|
||||||
if user_path.exists() and (user_path / "config.yaml").exists():
|
|
||||||
return user_path
|
return user_path
|
||||||
|
|
||||||
legacy_path = paths.agent_dir(name)
|
legacy_path = paths.agent_dir(name)
|
||||||
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
|
if legacy_path.exists():
|
||||||
return legacy_path
|
return legacy_path
|
||||||
|
|
||||||
return user_path
|
return user_path
|
||||||
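The variant of `resolve_agent_dir` that requires `config.yaml` treats a directory as an agent directory only when that file is present, so a directory created as a side effect of memory or storage writes is not mistaken for one. The sketch below is a simplified illustration using temporary directories. `pick_agent_dir` and the directory layout are assumptions made for the example, not the project's actual helpers or paths.

```python
import tempfile
from pathlib import Path


def pick_agent_dir(user_path: Path, legacy_path: Path) -> Path:
    """Prefer a directory only if it contains config.yaml; default to user_path."""
    for candidate in (user_path, legacy_path):
        if candidate.exists() and (candidate / "config.yaml").exists():
            return candidate
    # Neither location is a confirmed agent directory: fall back to the user path.
    return user_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        user = Path(tmp) / "users" / "alice" / "agents" / "demo"
        legacy = Path(tmp) / "agents" / "demo"
        user.mkdir(parents=True)    # leftover directory without config.yaml
        legacy.mkdir(parents=True)
        (legacy / "config.yaml").write_text("name: demo\n")
        # The leftover user directory is skipped in favour of the legacy one.
        print(pick_agent_dir(user, legacy) == legacy)
```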
|
|||||||
@@ -7,11 +7,10 @@ from typing import Any, Self
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
||||||
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
||||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
|
||||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||||
from deerflow.config.database_config import DatabaseConfig
|
from deerflow.config.database_config import DatabaseConfig
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig
|
from deerflow.config.extensions_config import ExtensionsConfig
|
||||||
@@ -117,13 +116,6 @@ class AppConfig(BaseModel):
|
|||||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||||
channel_connections: ChannelConnectionsConfig = Field(
|
|
||||||
default_factory=ChannelConnectionsConfig,
|
|
||||||
description=format_field_description(
|
|
||||||
"channel_connections",
|
|
||||||
field_doc="User-facing IM channel connection configuration.",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
@@ -156,21 +148,6 @@ class AppConfig(BaseModel):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("models", "tools", "tool_groups", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _coerce_null_list_sections(cls, value: Any) -> Any:
|
|
||||||
"""Treat a present-but-empty config section as an empty list.
|
|
||||||
|
|
||||||
Commenting out every entry under a top-level YAML key — e.g. ``models:``
|
|
||||||
with only comments beneath it, exactly as shipped in
|
|
||||||
``config.example.yaml`` — makes PyYAML parse the value as ``None``.
|
|
||||||
Without this, the documented ``cp config.example.yaml config.yaml``
|
|
||||||
first-run flow crashes with an opaque ``Input should be a valid list``
|
|
||||||
pydantic error. Coercing ``None`` to ``[]`` keeps that flow working and
|
|
||||||
matches the field's own ``default_factory=list``.
|
|
||||||
"""
|
|
||||||
return [] if value is None else value
|
|
||||||
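The validator's docstring describes a key that is present but has only comments beneath it, which parses to `None` rather than to an empty list. The sketch below shows the same coercion on a plain dictionary standing in for the parsed YAML. It deliberately avoids PyYAML and pydantic so it runs with the standard library alone; `coerce_null_list_sections` here is a free function written for illustration, not the class method from the diff.

```python
from typing import Any


def coerce_null_list_sections(config_data: dict[str, Any], keys: tuple[str, ...]) -> dict[str, Any]:
    """Return a copy where listed keys holding None are replaced by empty lists."""
    fixed = dict(config_data)
    for key in keys:
        # Only touch keys that are present; absent keys keep their model defaults.
        if key in fixed and fixed[key] is None:
            fixed[key] = []
    return fixed


if __name__ == "__main__":
    # Stand-in for a parsed config where `models:` had every entry commented out.
    parsed = {"models": None, "tools": [{"name": "web_fetch"}]}
    print(coerce_null_list_sections(parsed, ("models", "tools", "tool_groups")))
```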
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||||
"""Resolve the config file path.
|
"""Resolve the config file path.
|
||||||
@@ -232,11 +209,6 @@ class AppConfig(BaseModel):
|
|||||||
config_data["extensions"] = extensions_config.model_dump()
|
config_data["extensions"] = extensions_config.model_dump()
|
||||||
|
|
||||||
result = cls.model_validate(config_data)
|
result = cls.model_validate(config_data)
|
||||||
if not result.models:
|
|
||||||
logger.warning(
|
|
||||||
"No models are configured in %s. Add at least one entry under `models:` (see the commented examples in config.example.yaml) or run `make setup`.",
|
|
||||||
resolved_path,
|
|
||||||
)
|
|
||||||
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
|
acp_agents = cls._validate_acp_agents(config_data.get("acp_agents", {}))
|
||||||
cls._apply_singleton_configs(result, acp_agents)
|
cls._apply_singleton_configs(result, acp_agents)
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
"""Configuration for user-owned IM channel connections."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class SlackChannelConnectionConfig(BaseModel):
|
|
||||||
enabled: bool = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def configured(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class TelegramChannelConnectionConfig(BaseModel):
|
|
||||||
enabled: bool = False
|
|
||||||
bot_username: str = ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def configured(self) -> bool:
|
|
||||||
return bool(self.bot_username)
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannelConnectionConfig(BaseModel):
|
|
||||||
enabled: bool = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def configured(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class BindingCodeChannelConnectionConfig(BaseModel):
|
|
||||||
enabled: bool = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def configured(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelConnectionsConfig(BaseModel):
|
|
||||||
"""Top-level config for browser-connectable IM channels."""
|
|
||||||
|
|
||||||
enabled: bool = False
|
|
||||||
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
|
|
||||||
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
|
|
||||||
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
|
|
||||||
feishu: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
|
||||||
dingtalk: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
|
||||||
wechat: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
|
||||||
wecom: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
|
||||||
|
|
||||||
def provider_status(self, provider: str) -> dict[str, bool]:
|
|
||||||
config = getattr(self, provider, None)
|
|
||||||
if config is None:
|
|
||||||
return {"enabled": False, "configured": False}
|
|
||||||
enabled = bool(config.enabled)
|
|
||||||
return {
|
|
||||||
"enabled": enabled,
|
|
||||||
"configured": enabled and bool(config.configured),
|
|
||||||
}
|
|
||||||
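`provider_status` reports a provider as configured only when it is also enabled, and it answers with both flags off for an unknown provider name. The sketch below reproduces that rule with standard-library dataclasses instead of pydantic models. The class names `TelegramConfig` and `Connections` are stand-ins chosen for the example.

```python
from dataclasses import dataclass, field


@dataclass
class TelegramConfig:
    enabled: bool = False
    bot_username: str = ""

    @property
    def configured(self) -> bool:
        # Telegram counts as configured once a bot username is set.
        return bool(self.bot_username)


@dataclass
class Connections:
    telegram: TelegramConfig = field(default_factory=TelegramConfig)

    def provider_status(self, provider: str) -> dict[str, bool]:
        config = getattr(self, provider, None)
        if config is None:
            return {"enabled": False, "configured": False}
        enabled = bool(config.enabled)
        # "configured" is never reported for a disabled provider.
        return {"enabled": enabled, "configured": enabled and bool(config.configured)}


if __name__ == "__main__":
    conns = Connections(telegram=TelegramConfig(enabled=False, bot_username="demo_bot"))
    print(conns.provider_status("telegram"))
    print(conns.provider_status("unknown"))
```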
@@ -41,20 +41,6 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
|
|||||||
_checkpointer_config = config
|
_checkpointer_config = config
|
||||||
|
|
||||||
|
|
||||||
def ensure_config_loaded() -> None:
|
|
||||||
"""Lazily load app config when checkpointer config has not been initialized."""
|
|
||||||
from deerflow.config.app_config import _app_config, get_app_config
|
|
||||||
|
|
||||||
config = get_checkpointer_config()
|
|
||||||
if config is not None or _app_config is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
get_app_config()
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
|
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
|
||||||
"""Load checkpointer configuration from a dictionary."""
|
"""Load checkpointer configuration from a dictionary."""
|
||||||
global _checkpointer_config
|
global _checkpointer_config
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
"""Configuration for memory mechanism."""
|
"""Configuration for memory mechanism."""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -62,17 +60,6 @@ class MemoryConfig(BaseModel):
|
|||||||
le=8000,
|
le=8000,
|
||||||
description="Maximum tokens to use for memory injection",
|
description="Maximum tokens to use for memory injection",
|
||||||
)
|
)
|
||||||
token_counting: Literal["tiktoken", "char"] = Field(
|
|
||||||
default="tiktoken",
|
|
||||||
description=(
|
|
||||||
"Token counting strategy for memory-injection budgeting. "
|
|
||||||
"'tiktoken' is accurate but the encoding's BPE data may be "
|
|
||||||
"downloaded from a public network endpoint on first use, which "
|
|
||||||
"can block for a long time in network-restricted environments "
|
|
||||||
"(see issue #3402/#3429). 'char' uses a network-free "
|
|
||||||
"CJK-aware character-based estimate and never touches tiktoken."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Global configuration instance
|
# Global configuration instance
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -15,8 +14,6 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
|||||||
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
|
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
|
||||||
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
|
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _default_local_base_dir() -> Path:
|
def _default_local_base_dir() -> Path:
|
||||||
"""Return the caller project's writable DeerFlow state directory."""
|
"""Return the caller project's writable DeerFlow state directory."""
|
||||||
@@ -50,13 +47,7 @@ def make_safe_user_id(raw: str) -> str:
|
|||||||
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
||||||
if sanitized == raw:
|
if sanitized == raw:
|
||||||
return raw
|
return raw
|
||||||
digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||||
return f"{sanitized}-{digest}"
|
|
||||||
|
|
||||||
|
|
||||||
def _legacy_safe_user_id(raw: str, sanitized: str) -> str:
|
|
||||||
"""Bucket name produced by the previous (SHA-1) digest revision for ``raw``."""
|
|
||||||
digest = hashlib.sha1(raw.encode("utf-8"), usedforsecurity=False).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
|
||||||
return f"{sanitized}-{digest}"
|
return f"{sanitized}-{digest}"
|
||||||
|
|
||||||
|
|
||||||
@@ -181,32 +172,6 @@ class Paths:
|
|||||||
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
|
||||||
return self.base_dir / "users" / _validate_user_id(user_id)
|
return self.base_dir / "users" / _validate_user_id(user_id)
|
||||||
|
|
||||||
def prepare_user_dir_for_raw_id(self, raw_user_id: str) -> str:
|
|
||||||
"""Return the safe user ID and migrate this ID's legacy unsafe-id bucket.
|
|
||||||
|
|
||||||
A previous branch revision used SHA-1 for unsafe external user IDs.
|
|
||||||
New IDs use SHA-256; the legacy bucket name is recomputed from the same
|
|
||||||
raw ID, so only this user's own old bucket can ever be moved — a
|
|
||||||
different raw ID sharing the sanitized prefix produces a different
|
|
||||||
legacy digest and is never touched.
|
|
||||||
"""
|
|
||||||
safe_user_id = make_safe_user_id(raw_user_id)
|
|
||||||
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw_user_id)
|
|
||||||
if safe_user_id == raw_user_id:
|
|
||||||
return safe_user_id
|
|
||||||
|
|
||||||
users_dir = self.base_dir / "users"
|
|
||||||
target_dir = users_dir / safe_user_id
|
|
||||||
legacy_dir = users_dir / _legacy_safe_user_id(raw_user_id, sanitized)
|
|
||||||
try:
|
|
||||||
if target_dir.exists() or not legacy_dir.is_dir():
|
|
||||||
return safe_user_id
|
|
||||||
legacy_dir.rename(target_dir)
|
|
||||||
logger.info("Migrated legacy unsafe-id user directory to the current digest format")
|
|
||||||
except OSError:
|
|
||||||
logger.exception("Failed to migrate legacy unsafe-id user directory")
|
|
||||||
return safe_user_id
|
|
||||||
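The migration described above depends on the legacy bucket name being recomputed from the same raw ID, so only that user's own SHA-1 bucket can be renamed. The sketch below walks through this in a temporary directory under simplifying assumptions: `safe_id` and `migrate` are illustrative names, logging and `OSError` handling are left out, and the directory passed in plays the role of `{base_dir}/users`.

```python
import hashlib
import re
import tempfile
from pathlib import Path

_UNSAFE = re.compile(r"[^A-Za-z0-9_\-]")
_DIGEST_LEN = 16


def safe_id(raw: str, algo: str = "sha256") -> str:
    """Sanitize raw and append a truncated digest when any character was replaced."""
    sanitized = _UNSAFE.sub("-", raw)
    if sanitized == raw:
        return raw
    digest = hashlib.new(algo, raw.encode("utf-8")).hexdigest()[:_DIGEST_LEN]
    return f"{sanitized}-{digest}"


def migrate(users_dir: Path, raw: str) -> str:
    """Rename this raw ID's SHA-1 bucket to the SHA-256 name if only the old one exists."""
    current = safe_id(raw)
    if current == raw:
        return current  # already safe: no digest suffix, nothing to migrate
    target = users_dir / current
    legacy = users_dir / safe_id(raw, "sha1")
    if not target.exists() and legacy.is_dir():
        legacy.rename(target)
    return current


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        users = Path(tmp)
        raw = "alice@example.com"
        (users / safe_id(raw, "sha1")).mkdir()
        new_id = migrate(users, raw)
        print((users / new_id).is_dir())
```

Two raw IDs that sanitize to the same prefix, such as `a@b` and `a#b`, still get different digests, which is why one user's migration cannot move another user's directory.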
|
|
||||||
def user_memory_file(self, user_id: str) -> Path:
|
def user_memory_file(self, user_id: str) -> Path:
|
||||||
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
|
||||||
return self.user_dir(user_id) / "memory.json"
|
return self.user_dir(user_id) / "memory.json"
|
||||||
|
|||||||
@@ -56,9 +56,6 @@ STARTUP_ONLY_FIELDS: dict[str, str] = {
|
|||||||
# startup and the live channel clients are not rebuilt on
|
# startup and the live channel clients are not rebuilt on
|
||||||
# config.yaml edits.
|
# config.yaml edits.
|
||||||
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
|
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
|
||||||
"channel_connections": (
|
|
||||||
"start_channel_service() wires the connection repository and channel workers once at startup, and the channel-connections router caches the merged provider config on app.state; channel_connections.* edits need a restart."
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,20 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
class VolumeMountConfig(BaseModel):
|
class VolumeMountConfig(BaseModel):
|
||||||
"""Configuration for a volume mount."""
|
"""Configuration for a volume mount."""
|
||||||
|
|
||||||
host_path: str = Field(
|
host_path: str = Field(..., description="Path on the host machine")
|
||||||
...,
|
|
||||||
description=(
|
|
||||||
"Source path for the mount. Resolution depends on the active provider: "
|
|
||||||
"``LocalSandboxProvider`` checks this path from the gateway process — in "
|
|
||||||
"``make dev`` that is the host machine, but in Docker deployments "
|
|
||||||
"(``make up`` / docker-compose) it is the path *inside* the "
|
|
||||||
"``deer-flow-gateway`` container, so the host directory must also be "
|
|
||||||
"bind-mounted into the gateway service for the mount to take effect. "
|
|
||||||
"``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` "
|
|
||||||
"for the sandbox container, where it is resolved by the host Docker daemon "
|
|
||||||
"from the host machine's perspective."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
container_path: str = Field(..., description="Path inside the container")
|
container_path: str = Field(..., description="Path inside the container")
|
||||||
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
||||||
|
|
||||||
|
|||||||
@@ -114,27 +114,8 @@ class PatchedChatMiniMax(ChatOpenAI):
|
|||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
payload["extra_body"] = {"reasoning_split": True}
|
payload["extra_body"] = {"reasoning_split": True}
|
||||||
self._strip_user_message_names(payload)
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _strip_user_message_names(payload: dict) -> None:
|
|
||||||
"""Drop the per-message ``name`` field from user-role messages.
|
|
||||||
|
|
||||||
DeerFlow middlewares tag user messages with internal provenance names
|
|
||||||
(``user-input``, ``summary``, ``loop_warning``, ...). ``langchain_openai``
|
|
||||||
serializes those into the OpenAI-compatible request, but MiniMax requires
|
|
||||||
every user-role ``name`` to be identical and otherwise rejects the request
|
|
||||||
with ``invalid params, user name must be consistent (2013)``. MiniMax does
|
|
||||||
not use the per-message author name, so strip it.
|
|
||||||
"""
|
|
||||||
messages = payload.get("messages")
|
|
||||||
if not isinstance(messages, list):
|
|
||||||
return
|
|
||||||
for message in messages:
|
|
||||||
if isinstance(message, dict) and message.get("role") == "user":
|
|
||||||
message.pop("name", None)
|
|
||||||
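As a usage illustration of the name-stripping step, the sketch below applies the same logic to a hand-written payload. The payload contents are hypothetical; only the `role` and `name` keys matter, and the free function `strip_user_message_names` mirrors the static method in the diff rather than replacing it.

```python
def strip_user_message_names(payload: dict) -> None:
    """Remove the `name` key from user-role messages, in place."""
    messages = payload.get("messages")
    if not isinstance(messages, list):
        return
    for message in messages:
        if isinstance(message, dict) and message.get("role") == "user":
            message.pop("name", None)


if __name__ == "__main__":
    # Hypothetical request payload with two differently named user messages.
    payload = {
        "messages": [
            {"role": "user", "name": "user-input", "content": "hi"},
            {"role": "user", "name": "summary", "content": "earlier context"},
            {"role": "assistant", "name": "agent", "content": "hello"},
        ]
    }
    strip_user_message_names(payload)
    print(payload["messages"])
```

Only user-role entries are changed; the assistant message keeps its `name`.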
|
|
||||||
def _convert_chunk_to_generation_chunk(
|
def _convert_chunk_to_generation_chunk(
|
||||||
self,
|
self,
|
||||||
chunk: dict,
|
chunk: dict,
|
||||||
|
|||||||
@@ -1,175 +0,0 @@
|
|||||||
"""Patched ChatOpenAI adapter for StepFun reasoning models.
|
|
||||||
|
|
||||||
StepFun returns ``reasoning`` (or ``reasoning_content`` with deepseek-style) in
|
|
||||||
both streaming deltas and non-streaming responses. Standard ``ChatOpenAI``
|
|
||||||
ignores these non-standard fields, so reasoning content is silently dropped.
|
|
||||||
This adapter captures reasoning from all response paths and replays it on
|
|
||||||
historical assistant messages for multi-turn tool-call conversations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.language_models import LanguageModelInput
|
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from deerflow.models.assistant_payload_replay import (
|
|
||||||
restore_assistant_payloads,
|
|
||||||
restore_reasoning_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
_MISSING = object()
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_reasoning(value: Any) -> str | object:
|
|
||||||
"""Return reasoning content from a dict/Pydantic object.
|
|
||||||
|
|
||||||
StepFun may return reasoning via ``reasoning`` (default) or
|
|
||||||
``reasoning_content`` (deepseek-style). Check both fields.
|
|
||||||
"""
|
|
||||||
if isinstance(value, Mapping):
|
|
||||||
# Check reasoning_content first (deepseek-style), then reasoning (default)
|
|
||||||
for field in ("reasoning_content", "reasoning"):
|
|
||||||
if field in value and value[field] is not None:
|
|
||||||
return value[field]
|
|
||||||
return _MISSING
|
|
||||||
|
|
||||||
# Pydantic / SDK object attributes
|
|
||||||
for field in ("reasoning_content", "reasoning"):
|
|
||||||
attr = getattr(value, field, _MISSING)
|
|
||||||
if attr is not _MISSING and attr is not None:
|
|
||||||
return attr
|
|
||||||
|
|
||||||
# Some SDK versions store extra fields in model_extra
|
|
||||||
model_extra = getattr(value, "model_extra", None)
|
|
||||||
if isinstance(model_extra, Mapping):
|
|
||||||
for field in ("reasoning_content", "reasoning"):
|
|
||||||
if field in model_extra and model_extra[field] is not None:
|
|
||||||
return model_extra[field]
|
|
||||||
|
|
||||||
return _MISSING
|
|
||||||
|
|
||||||
|
|
||||||
def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk:
|
|
||||||
"""Return a copy of *message* with reasoning_content stored in additional_kwargs."""
|
|
||||||
additional_kwargs = dict(message.additional_kwargs)
|
|
||||||
if additional_kwargs.get("reasoning_content") != reasoning:
|
|
||||||
additional_kwargs["reasoning_content"] = reasoning
|
|
||||||
return message.model_copy(update={"additional_kwargs": additional_kwargs})
|
|
||||||
|
|
||||||
|
|
||||||
def _get_typed_choice_message(response: Any, index: int) -> Any:
|
|
||||||
"""Extract the SDK-typed choice message at *index*, if available."""
|
|
||||||
choices = getattr(response, "choices", None)
|
|
||||||
if choices is None:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
return choices[index].message
|
|
||||||
except (AttributeError, IndexError, TypeError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class PatchedChatStepFun(ChatOpenAI):
|
|
||||||
"""ChatOpenAI with full reasoning support for StepFun models.
|
|
||||||
|
|
||||||
Captures ``reasoning`` / ``reasoning_content`` from both streaming and
|
|
||||||
non-streaming responses and replays it on historical assistant messages in
|
|
||||||
multi-turn tool-call conversations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_lc_serializable(cls) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def lc_secrets(self) -> dict[str, str]:
|
|
||||||
return {"api_key": "STEPFUN_API_KEY", "openai_api_key": "STEPFUN_API_KEY"}
|
|
||||||
|
|
||||||
# --- Request payload replay ---
|
|
||||||
|
|
||||||
def _get_request_payload(
|
|
||||||
self,
|
|
||||||
input_: LanguageModelInput,
|
|
||||||
*,
|
|
||||||
stop: list[str] | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> dict:
|
|
||||||
"""Restore ``reasoning_content`` on historical assistant messages."""
|
|
||||||
original_messages = self._convert_input(input_).to_messages()
|
|
||||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
|
||||||
|
|
||||||
restore_assistant_payloads(
|
|
||||||
payload.get("messages", []),
|
|
||||||
original_messages,
|
|
||||||
restore_reasoning_content,
|
|
||||||
)
|
|
||||||
|
|
||||||
return payload
|
|
||||||
|
|
||||||
# --- Streaming reasoning capture ---
|
|
||||||
|
|
||||||
def _convert_chunk_to_generation_chunk(
|
|
||||||
self,
|
|
||||||
chunk: dict,
|
|
||||||
default_chunk_class: type,
|
|
||||||
base_generation_info: dict | None,
|
|
||||||
) -> ChatGenerationChunk | None:
|
|
||||||
"""Capture ``reasoning`` / ``reasoning_content`` from streaming deltas."""
|
|
||||||
generation_chunk = super()._convert_chunk_to_generation_chunk(
|
|
||||||
chunk,
|
|
||||||
default_chunk_class,
|
|
||||||
base_generation_info,
|
|
||||||
)
|
|
||||||
if generation_chunk is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
choices = chunk.get("choices", [])
|
|
||||||
if choices:
|
|
||||||
delta = choices[0].get("delta") or {}
|
|
||||||
reasoning = _extract_reasoning(delta)
|
|
||||||
if reasoning is not _MISSING and isinstance(generation_chunk.message, AIMessageChunk):
|
|
||||||
generation_chunk = ChatGenerationChunk(
|
|
||||||
message=_with_reasoning_content(generation_chunk.message, reasoning),
|
|
||||||
generation_info=generation_chunk.generation_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
return generation_chunk
|
|
||||||
|
|
||||||
# --- Non-streaming reasoning capture ---
|
|
||||||
|
|
||||||
def _create_chat_result(
|
|
||||||
self,
|
|
||||||
response: dict | Any,
|
|
||||||
generation_info: dict | None = None,
|
|
||||||
) -> ChatResult:
|
|
||||||
"""Extract ``reasoning`` / ``reasoning_content`` from non-streaming responses."""
|
|
||||||
result = super()._create_chat_result(response, generation_info)
|
|
||||||
response_dict = response if isinstance(response, dict) else response.model_dump()
|
|
||||||
choices = response_dict.get("choices", [])
|
|
||||||
|
|
||||||
patched_generations: list[ChatGeneration] | None = None
|
|
||||||
for index, generation in enumerate(result.generations):
|
|
||||||
choice = choices[index] if index < len(choices) else {}
|
|
||||||
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
|
|
||||||
reasoning = _extract_reasoning(choice_message)
|
|
||||||
|
|
||||||
if reasoning is _MISSING and not isinstance(response, dict):
|
|
||||||
reasoning = _extract_reasoning(_get_typed_choice_message(response, index))
|
|
||||||
|
|
||||||
message = generation.message
|
|
||||||
if reasoning is not _MISSING and isinstance(message, AIMessage):
|
|
||||||
if patched_generations is None:
|
|
||||||
patched_generations = list(result.generations)
|
|
||||||
patched_generations[index] = ChatGeneration(
|
|
||||||
message=_with_reasoning_content(message, reasoning),
|
|
||||||
generation_info=generation.generation_info,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChatResult(
|
|
||||||
generations=patched_generations or result.generations,
|
|
||||||
llm_output=result.llm_output,
|
|
||||||
)
|
|
||||||
@@ -1,21 +0,0 @@
-"""User-owned IM channel connection persistence."""
-
-from deerflow.persistence.channel_connections.model import (
-    ChannelConnectionRow,
-    ChannelConversationRow,
-    ChannelCredentialRow,
-    ChannelOAuthStateRow,
-)
-from deerflow.persistence.channel_connections.sql import (
-    ChannelConnectionRepository,
-    ChannelCredentialCipher,
-)
-
-__all__ = [
-    "ChannelConnectionRepository",
-    "ChannelConnectionRow",
-    "ChannelConversationRow",
-    "ChannelCredentialCipher",
-    "ChannelCredentialRow",
-    "ChannelOAuthStateRow",
-]
@@ -1,111 +0,0 @@
-"""ORM models for user-owned IM channel connections."""
-
-from __future__ import annotations
-
-from datetime import UTC, datetime
-
-from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
-from sqlalchemy.orm import Mapped, mapped_column
-
-from deerflow.persistence.base import Base
-
-
-def _utc_now() -> datetime:
-    return datetime.now(UTC)
-
-
-class ChannelConnectionRow(Base):
-    __tablename__ = "channel_connections"
-
-    id: Mapped[str] = mapped_column(String(64), primary_key=True)
-    owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
-    provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
-    status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected")
-
-    external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
-    external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
-    workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
-    workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
-    bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
-
-    scopes_json: Mapped[list] = mapped_column(JSON, default=list)
-    capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict)
-    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
-
-    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
-    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
-    last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
-    last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
-
-    __table_args__ = (
-        UniqueConstraint(
-            "owner_user_id",
-            "provider",
-            "external_account_id",
-            "workspace_id",
-            name="uq_channel_connection_owner_provider_identity",
-        ),
-        Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
-    )
-
-
-class ChannelCredentialRow(Base):
-    __tablename__ = "channel_credentials"
-
-    connection_id: Mapped[str] = mapped_column(
-        String(64),
-        ForeignKey("channel_connections.id", ondelete="CASCADE"),
-        primary_key=True,
-    )
-    encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True)
-    encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True)
-    token_type: Mapped[str | None] = mapped_column(String(32), nullable=True)
-    expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
-    refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
-    encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True)
-    version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
-    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
-
-
-class ChannelOAuthStateRow(Base):
-    __tablename__ = "channel_oauth_states"
-
-    state_hash: Mapped[str] = mapped_column(String(128), primary_key=True)
-    owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
-    provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
-    code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
-    nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
-    redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True)
-    requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list)
-    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
-    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
-    consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
-    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
-
-
-class ChannelConversationRow(Base):
-    __tablename__ = "channel_conversations"
-
-    id: Mapped[str] = mapped_column(String(64), primary_key=True)
-    connection_id: Mapped[str] = mapped_column(
-        String(64),
-        ForeignKey("channel_connections.id", ondelete="CASCADE"),
-        nullable=False,
-        index=True,
-    )
-    owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
-    provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
-    external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False)
-    external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
-    thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
-    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
-    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
-
-    __table_args__ = (
-        UniqueConstraint(
-            "connection_id",
-            "external_conversation_id",
-            "external_topic_id",
-            name="uq_channel_conversation_connection_external",
-        ),
-    )
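The removed models store a missing `external_account_id`, `workspace_id`, or `external_topic_id` as an empty string in a non-nullable column instead of as NULL. One plausible motivation (an assumption; the source does not state it) is that SQL treats NULLs as distinct inside a UNIQUE constraint, so rows with a NULL identity would never collide. The sketch below illustrates that behaviour with the standard-library `sqlite3` module; the `demo` table and `count_after_duplicate_insert` helper are hypothetical.

```python
import sqlite3


def count_after_duplicate_insert(missing_value):
    """Insert the same (owner, workspace) pair twice; return how many rows survive."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE demo (owner TEXT NOT NULL, workspace TEXT, UNIQUE (owner, workspace))"
    )
    for _ in range(2):
        try:
            conn.execute("INSERT INTO demo VALUES (?, ?)", ("u1", missing_value))
        except sqlite3.IntegrityError:
            pass  # the UNIQUE constraint rejected the duplicate
    (count,) = conn.execute("SELECT COUNT(*) FROM demo").fetchone()
    conn.close()
    return count
```

Calling the helper with `None` leaves two rows because the NULLs do not conflict, while calling it with `""` leaves one, which is the behaviour a uniqueness rule over optional identities needs.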
@@ -1,387 +0,0 @@
-"""SQL repository for user-owned IM channel connections."""
-
-from __future__ import annotations
-
-import base64
-import hashlib
-import json
-import logging
-import uuid
-from datetime import UTC, datetime
-from typing import Any
-
-from cryptography.fernet import Fernet, InvalidToken
-from sqlalchemy import delete, func, select, update
-from sqlalchemy.exc import IntegrityError
-from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
-
-from deerflow.persistence.channel_connections.model import (
-    ChannelConnectionRow,
-    ChannelConversationRow,
-    ChannelCredentialRow,
-    ChannelOAuthStateRow,
-)
-from deerflow.utils.time import coerce_iso
-
-logger = logging.getLogger(__name__)
-
-
-class ChannelCredentialCipher:
-    """Encrypts provider credentials before they are persisted."""
-
-    def __init__(self, fernet: Fernet) -> None:
-        self._fernet = fernet
-
-    @classmethod
-    def from_key(cls, key: str) -> ChannelCredentialCipher:
-        digest = hashlib.sha256(key.encode("utf-8")).digest()
-        return cls(Fernet(base64.urlsafe_b64encode(digest)))
-
-    def encrypt_text(self, value: str | None) -> str | None:
-        if value is None:
-            return None
-        return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")
-
-    def decrypt_text(self, value: str | None) -> str | None:
-        if value is None:
-            return None
-        token = value.removeprefix("fernet:v1:")
-        return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")
-
-
-class ChannelConnectionRepository:
-    """Persistence facade for channel connections, credentials, and conversations."""
-
-    def __init__(
-        self,
-        session_factory: async_sessionmaker[AsyncSession],
-        *,
-        cipher: ChannelCredentialCipher | None = None,
-    ) -> None:
-        self.session_factory = session_factory
-        self._cipher = cipher
-
-    async def close(self) -> None:
-        from deerflow.persistence.engine import close_engine
-
-        await close_engine()
-
-    @staticmethod
-    def _new_id() -> str:
-        return uuid.uuid4().hex
-
-    @staticmethod
-    def _normalize_optional_identity(value: str | None) -> str:
-        return value or ""
-
-    @staticmethod
-    def _coerce_datetime(value: datetime | None) -> datetime | None:
-        if value is None or value.tzinfo is not None:
-            return value
-        return value.replace(tzinfo=UTC)
-
-    def _encrypt_optional_secret(self, value: str | None) -> str | None:
-        if value is None:
-            return None
-        if self._cipher is None:
-            raise RuntimeError("channel connection encryption key is required")
-        return self._cipher.encrypt_text(value)
-
-    @staticmethod
-    def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
-        data = row.to_dict()
-        data["external_account_id"] = data["external_account_id"] or None
-        data["workspace_id"] = data["workspace_id"] or None
-        data["scopes"] = data.pop("scopes_json") or []
-        data["capabilities"] = data.pop("capabilities_json") or {}
-        data["metadata"] = data.pop("metadata_json") or {}
-        for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
-            value = data.get(key)
-            if isinstance(value, datetime):
-                data[key] = coerce_iso(value)
-        return data
-
-    async def upsert_connection(
-        self,
-        *,
-        owner_user_id: str,
-        provider: str,
-        external_account_id: str | None = None,
-        external_account_name: str | None = None,
-        workspace_id: str | None = None,
-        workspace_name: str | None = None,
-        bot_user_id: str | None = None,
-        scopes: list[str] | None = None,
-        capabilities: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        status: str = "connected",
-    ) -> dict[str, Any]:
-        external_account_id_value = self._normalize_optional_identity(external_account_id)
-        workspace_id_value = self._normalize_optional_identity(workspace_id)
-
-        def _apply(row: ChannelConnectionRow) -> None:
-            row.status = status
-            row.external_account_name = external_account_name
-            row.workspace_name = workspace_name
-            row.bot_user_id = bot_user_id
-            row.scopes_json = list(scopes or [])
-            row.capabilities_json = dict(capabilities or {})
-            row.metadata_json = dict(metadata or {})
-
-        stmt = select(ChannelConnectionRow).where(
-            ChannelConnectionRow.owner_user_id == owner_user_id,
-            ChannelConnectionRow.provider == provider,
-            ChannelConnectionRow.external_account_id == external_account_id_value,
-            ChannelConnectionRow.workspace_id == workspace_id_value,
-        )
-        async with self.session_factory() as session:
-            row = (await session.execute(stmt)).scalar_one_or_none()
-            if row is None:
-                row = ChannelConnectionRow(
-                    id=self._new_id(),
-                    owner_user_id=owner_user_id,
-                    provider=provider,
-                    external_account_id=external_account_id_value,
-                    workspace_id=workspace_id_value,
-                )
-                session.add(row)
-
-            _apply(row)
-            try:
-                await session.commit()
-            except IntegrityError:
-                # A concurrent writer inserted the same identity first; retry as
-                # an update of that row.
-                await session.rollback()
-                row = (await session.execute(stmt)).scalar_one()
-                _apply(row)
-                await session.commit()
-            await session.refresh(row)
-            return self._connection_to_dict(row)
-
-    async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
-        async with self.session_factory() as session:
-            result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
-            return [self._connection_to_dict(row) for row in result.scalars()]
-
-    async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
-        async with self.session_factory() as session:
-            row = await session.get(ChannelConnectionRow, connection_id)
-            if row is None or row.owner_user_id != owner_user_id:
-                return False
-
-            row.status = "revoked"
-            credential = await session.get(ChannelCredentialRow, connection_id)
-            if credential is not None:
-                await session.delete(credential)
-            await session.commit()
-            return True
-
-    async def store_credentials(
-        self,
-        connection_id: str,
-        *,
-        access_token: str | None,
-        refresh_token: str | None = None,
-        token_type: str | None = None,
-        expires_at: datetime | None = None,
-        refresh_expires_at: datetime | None = None,
-        extra: dict[str, Any] | None = None,
-    ) -> None:
-        if self._cipher is None:
-            raise RuntimeError("channel connection encryption key is required")
-        async with self.session_factory() as session:
-            row = await session.get(ChannelCredentialRow, connection_id)
-            if row is None:
-                row = ChannelCredentialRow(connection_id=connection_id)
-                session.add(row)
-            row.encrypted_access_token = self._cipher.encrypt_text(access_token)
-            row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
-            row.token_type = token_type
-            row.expires_at = expires_at
-            row.refresh_expires_at = refresh_expires_at
-            row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
-            row.version = (row.version or 0) + 1
-            await session.commit()
-
-    async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
-        if self._cipher is None:
-            return None
-        async with self.session_factory() as session:
-            row = await session.get(ChannelCredentialRow, connection_id)
-            if row is None:
-                return None
-        try:
-            extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
-            return {
-                "connection_id": row.connection_id,
-                "access_token": self._cipher.decrypt_text(row.encrypted_access_token),
-                "refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
-                "token_type": row.token_type,
-                "expires_at": self._coerce_datetime(row.expires_at),
-                "refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
-                "extra": json.loads(extra_raw) if extra_raw else {},
-            }
-        except (InvalidToken, UnicodeError, json.JSONDecodeError):
-            logger.warning(
-                "Unable to decrypt channel connection credentials; treating credentials as unavailable",
-                exc_info=True,
-            )
-            return None
-
-    @staticmethod
-    def hash_state(state: str) -> str:
-        return hashlib.sha256(state.encode("utf-8")).hexdigest()
-
-    async def create_oauth_state(
-        self,
-        *,
-        owner_user_id: str,
-        provider: str,
-        state: str,
-        expires_at: datetime,
-        code_verifier: str | None = None,
-        nonce_hash: str | None = None,
-        redirect_after: str | None = None,
-        requested_scopes: list[str] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> None:
-        row = ChannelOAuthStateRow(
-            state_hash=self.hash_state(state),
-            owner_user_id=owner_user_id,
-            provider=provider,
-            code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
-            nonce_hash=nonce_hash,
-            redirect_after=redirect_after,
-            requested_scopes_json=list(requested_scopes or []),
-            metadata_json=dict(metadata or {}),
-            expires_at=expires_at,
-        )
-        async with self.session_factory() as session:
-            session.add(row)
-            await session.commit()
-
-    async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
-        async with self.session_factory() as session:
-            result = await session.execute(
-                select(func.count())
-                .select_from(ChannelOAuthStateRow)
-                .where(
-                    ChannelOAuthStateRow.owner_user_id == owner_user_id,
-                    ChannelOAuthStateRow.provider == provider,
-                )
-            )
-            return int(result.scalar_one())
-
-    async def consume_oauth_state(
-        self,
-        *,
-        provider: str,
-        state: str,
-        now: datetime | None = None,
-    ) -> dict[str, Any] | None:
-        current_time = now or datetime.now(UTC)
-        state_hash = self.hash_state(state)
-        async with self.session_factory() as session:
-            await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
-            row = await session.get(ChannelOAuthStateRow, state_hash)
-            if row is None or row.provider != provider or row.consumed_at is not None:
-                await session.commit()
-                return None
-            expires_at = self._coerce_datetime(row.expires_at)
-            if expires_at is not None and expires_at < current_time:
-                await session.commit()
-                return None
-
-            # Conditional UPDATE so two concurrent workers cannot both consume
-            # the same binding code: only the writer that flips consumed_at
-            # from NULL wins.
-            result = await session.execute(
-                update(ChannelOAuthStateRow)
-                .where(
-                    ChannelOAuthStateRow.state_hash == state_hash,
-                    ChannelOAuthStateRow.consumed_at.is_(None),
-                )
-                .values(consumed_at=current_time)
-            )
-            await session.commit()
-            if result.rowcount != 1:
-                return None
-            return {
-                "owner_user_id": row.owner_user_id,
-                "provider": row.provider,
-                "requested_scopes": row.requested_scopes_json or [],
-                "metadata": row.metadata_json or {},
-                "redirect_after": row.redirect_after,
-            }
-
-    async def find_connection_by_external_identity(
-        self,
-        *,
-        provider: str,
-        external_account_id: str,
-        workspace_id: str | None = None,
-    ) -> dict[str, Any] | None:
-        async with self.session_factory() as session:
-            result = await session.execute(
-                select(ChannelConnectionRow)
-                .where(
-                    ChannelConnectionRow.provider == provider,
-                    ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
-                    ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
-                    ChannelConnectionRow.status == "connected",
-                )
-                .order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
-                .limit(1)
-            )
-            row = result.scalar_one_or_none()
-            return self._connection_to_dict(row) if row is not None else None
-
-    async def set_thread_id(
-        self,
-        *,
-        connection_id: str,
-        owner_user_id: str,
-        provider: str,
-        external_conversation_id: str,
-        thread_id: str,
-        external_topic_id: str | None = None,
-    ) -> None:
-        topic_id = external_topic_id or ""
-        async with self.session_factory() as session:
-            stmt = select(ChannelConversationRow).where(
-                ChannelConversationRow.connection_id == connection_id,
-                ChannelConversationRow.external_conversation_id == external_conversation_id,
-                ChannelConversationRow.external_topic_id == topic_id,
-            )
-            row = (await session.execute(stmt)).scalar_one_or_none()
-            if row is None:
-                row = ChannelConversationRow(
-                    id=self._new_id(),
-                    connection_id=connection_id,
-                    owner_user_id=owner_user_id,
-                    provider=provider,
-                    external_conversation_id=external_conversation_id,
-                    external_topic_id=topic_id,
-                    thread_id=thread_id,
-                )
-                session.add(row)
-            else:
-                row.thread_id = thread_id
-                row.owner_user_id = owner_user_id
-                row.provider = provider
-            await session.commit()
-
-    async def get_thread_id(
-        self,
-        connection_id: str,
-        external_conversation_id: str,
-        external_topic_id: str | None = None,
-    ) -> str | None:
-        async with self.session_factory() as session:
-            stmt = select(ChannelConversationRow.thread_id).where(
-                ChannelConversationRow.connection_id == connection_id,
-                ChannelConversationRow.external_conversation_id == external_conversation_id,
-                ChannelConversationRow.external_topic_id == (external_topic_id or ""),
-            )
-            return (await session.execute(stmt)).scalar_one_or_none()
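The removed `consume_oauth_state` relies on a conditional UPDATE so that only one caller can flip `consumed_at` from NULL, and it uses the affected-row count to decide who won. A minimal synchronous sketch of the same idea with the standard-library `sqlite3` module follows; the `oauth_states` demo table and the `make_demo_db` and `consume_once` helpers are hypothetical, and the removed code does this through SQLAlchemy's async session instead.

```python
import sqlite3


def make_demo_db() -> sqlite3.Connection:
    """Create an in-memory table holding one unconsumed state."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE oauth_states (state_hash TEXT PRIMARY KEY, consumed_at TEXT)")
    conn.execute("INSERT INTO oauth_states VALUES ('abc', NULL)")
    conn.commit()
    return conn


def consume_once(conn: sqlite3.Connection, state_hash: str, now: str) -> bool:
    """Mark the state consumed; True only for the caller that flipped it from NULL."""
    cursor = conn.execute(
        "UPDATE oauth_states SET consumed_at = ? "
        "WHERE state_hash = ? AND consumed_at IS NULL",
        (now, state_hash),
    )
    conn.commit()
    # rowcount is 1 only when this statement actually changed the row.
    return cursor.rowcount == 1
```

Because the `consumed_at IS NULL` check and the write happen in one statement, a second caller sees zero affected rows rather than racing on a separate read-then-write.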
@@ -14,26 +14,10 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and
 there is no matching entity directory.
 """
 
-from deerflow.persistence.channel_connections.model import (
-    ChannelConnectionRow,
-    ChannelConversationRow,
-    ChannelCredentialRow,
-    ChannelOAuthStateRow,
-)
 from deerflow.persistence.feedback.model import FeedbackRow
 from deerflow.persistence.models.run_event import RunEventRow
 from deerflow.persistence.run.model import RunRow
 from deerflow.persistence.thread_meta.model import ThreadMetaRow
 from deerflow.persistence.user.model import UserRow
 
-__all__ = [
-    "ChannelConnectionRow",
-    "ChannelConversationRow",
-    "ChannelCredentialRow",
-    "ChannelOAuthStateRow",
-    "FeedbackRow",
-    "RunEventRow",
-    "RunRow",
-    "ThreadMetaRow",
-    "UserRow",
-]
+__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
@@ -71,15 +71,6 @@ class ThreadMetaStore(abc.ABC):
         """
         pass
 
-    @abc.abstractmethod
-    async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
-        """Move a thread metadata row to a new owner.
-
-        Intended for trusted internal repair/migration paths. No-op if the
-        row does not exist or the caller fails the owner check.
-        """
-        pass
-
     @abc.abstractmethod
     async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
         """Check if ``user_id`` has access to ``thread_id``."""
@@ -127,14 +127,6 @@ class MemoryThreadMetaStore(ThreadMetaStore):
         record["updated_at"] = now_iso()
         await self._store.aput(THREADS_NS, thread_id, record)
 
-    async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
-        record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
-        if record is None:
-            return
-        record["user_id"] = owner_user_id
-        record["updated_at"] = now_iso()
-        await self._store.aput(THREADS_NS, thread_id, record)
-
     async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
         record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
         if record is None:
@@ -211,21 +211,6 @@ class ThreadMetaRepository(ThreadMetaStore):
             row.updated_at = datetime.now(UTC)
             await session.commit()
 
-    async def update_owner(
-        self,
-        thread_id: str,
-        owner_user_id: str,
-        *,
-        user_id: str | None | _AutoSentinel = AUTO,
-    ) -> None:
-        """Move a thread metadata row to ``owner_user_id``."""
-        resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
-        async with self._sf() as session:
-            if not await self._check_ownership(session, thread_id, resolved_user_id):
-                return
-            await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
-            await session.commit()
-
     async def delete(
         self,
         thread_id: str,
@@ -7,7 +7,7 @@ directly from ``deerflow.runtime``.
 
 from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
 from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
-from .serialization import serialize, serialize_channel_values, serialize_channel_values_for_api, serialize_lc_object, serialize_messages_tuple, strip_data_url_image_blocks
+from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
 from .store import get_store, make_store, reset_store, store_context
 from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
 
@@ -29,10 +29,8 @@ __all__ = [
     # serialization
     "serialize",
     "serialize_channel_values",
-    "serialize_channel_values_for_api",
     "serialize_lc_object",
     "serialize_messages_tuple",
-    "strip_data_url_image_blocks",
     # store
     "get_store",
     "make_store",
@@ -21,13 +21,12 @@ from __future__ import annotations
 
 import contextlib
 import logging
-import threading
 from collections.abc import Iterator
 
 from langgraph.types import Checkpointer
 
 from deerflow.config.app_config import get_app_config
-from deerflow.config.checkpointer_config import CheckpointerConfig, ensure_config_loaded
+from deerflow.config.checkpointer_config import CheckpointerConfig
 from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
 
 logger = logging.getLogger(__name__)
@@ -101,7 +100,6 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
 
 _checkpointer: Checkpointer | None = None
 _checkpointer_ctx = None  # open context manager keeping the connection alive
-_checkpointer_lock = threading.Lock()
 
 
 def get_checkpointer() -> Checkpointer:
@@ -118,18 +116,25 @@ def get_checkpointer() -> Checkpointer:
     if _checkpointer is not None:
         return _checkpointer
 
-    # Config loading can reset both persistence singletons. Keep it outside
-    # this provider lock to avoid cross-provider lock-order inversion.
-    ensure_config_loaded()
-
-    with _checkpointer_lock:
-        if _checkpointer is not None:
-            return _checkpointer
-
+    # Ensure app config is loaded before checking checkpointer config
+    # This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
+    # but hasn't been loaded yet
+    from deerflow.config.app_config import _app_config
     from deerflow.config.checkpointer_config import get_checkpointer_config
 
     config = get_checkpointer_config()
 
+    if config is None and _app_config is None:
+        # Only load app config lazily when neither the app config nor an explicit
+        # checkpointer config has been initialized yet. This keeps tests that
+        # intentionally set the global checkpointer config isolated from any
+        # ambient config.yaml on disk.
+        try:
+            get_app_config()
+        except FileNotFoundError:
+            # In test environments without config.yaml, this is expected.
+            pass
+        config = get_checkpointer_config()
     if config is None:
         from langgraph.checkpoint.memory import InMemorySaver
 
@@ -137,10 +142,8 @@ def get_checkpointer() -> Checkpointer:
|
|||||||
_checkpointer = InMemorySaver()
|
_checkpointer = InMemorySaver()
|
||||||
return _checkpointer
|
return _checkpointer
|
||||||
|
|
||||||
checkpointer_ctx = _sync_checkpointer_cm(config)
|
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||||
checkpointer = checkpointer_ctx.__enter__()
|
_checkpointer = _checkpointer_ctx.__enter__()
|
||||||
_checkpointer_ctx = checkpointer_ctx
|
|
||||||
_checkpointer = checkpointer
|
|
||||||
|
|
||||||
return _checkpointer
|
return _checkpointer
|
||||||
|
|
||||||
@@ -152,7 +155,6 @@ def reset_checkpointer() -> None:
|
|||||||
Useful in tests or after a configuration change.
|
Useful in tests or after a configuration change.
|
||||||
"""
|
"""
|
||||||
global _checkpointer, _checkpointer_ctx
|
global _checkpointer, _checkpointer_ctx
|
||||||
with _checkpointer_lock:
|
|
||||||
if _checkpointer_ctx is not None:
|
if _checkpointer_ctx is not None:
|
||||||
try:
|
try:
|
||||||
_checkpointer_ctx.__exit__(None, None, None)
|
_checkpointer_ctx.__exit__(None, None, None)
|
||||||
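The variant of `get_checkpointer` in this diff that takes `_checkpointer_lock` follows a double-checked locking shape: an unlocked fast path, a second check under the lock, and publication of the singleton only after the context manager has been entered. A minimal sketch of that shape, assuming a hypothetical `build` factory standing in for `_sync_checkpointer_cm(config).__enter__()` (the names `get_instance` and `build` are illustrative and not part of DeerFlow):

```python
import threading

_instance = None
_lock = threading.Lock()


def get_instance(build):
    """Return the process-wide singleton, creating it at most once."""
    global _instance
    # Fast path: no lock is needed once the singleton has been published.
    if _instance is not None:
        return _instance
    with _lock:
        # Second check: another thread may have finished while we waited.
        if _instance is not None:
            return _instance
        # Build into a local first so a failing build() never leaves a
        # half-initialised value in the global.
        instance = build()
        _instance = instance
        return _instance
```

Assigning to a local before publishing mirrors the `checkpointer_ctx` / `checkpointer` locals in the locked variant: if `__enter__()` raises, the module globals stay `None` and a later call can retry.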
|
|||||||
@@ -6,7 +6,6 @@ since all mutations happen within the same event loop).
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import bisect
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
@@ -14,11 +13,7 @@ from deerflow.runtime.events.store.base import RunEventStore
|
|||||||
|
|
||||||
class MemoryRunEventStore(RunEventStore):
|
class MemoryRunEventStore(RunEventStore):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._events: dict[str, list[dict]] = {} # thread_id -> seq-sorted event list
|
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
|
||||||
# Messages-only projection of ``_events`` (same dict objects, no copies),
|
|
||||||
# kept in seq order so message pagination is O(log m + page) via bisect
|
|
||||||
# instead of re-scanning every event on each request.
|
|
||||||
self._messages: dict[str, list[dict]] = {} # thread_id -> seq-sorted message list
|
|
||||||
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
||||||
|
|
||||||
def _next_seq(self, thread_id: str) -> int:
|
def _next_seq(self, thread_id: str) -> int:
|
||||||
@@ -50,8 +45,6 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||||
}
|
}
|
||||||
self._events.setdefault(thread_id, []).append(record)
|
self._events.setdefault(thread_id, []).append(record)
|
||||||
if category == "message":
|
|
||||||
self._messages.setdefault(thread_id, []).append(record)
|
|
||||||
return record
|
return record
|
||||||
|
|
||||||
async def put(
|
async def put(
|
||||||
@@ -83,20 +76,18 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||||
# ``messages`` is messages-only and seq-sorted, so the seq window is a
|
all_events = self._events.get(thread_id, [])
|
||||||
# contiguous slice located with bisect (O(log m)) rather than a full scan.
|
messages = [e for e in all_events if e["category"] == "message"]
|
||||||
messages = self._messages.get(thread_id, [])
|
|
||||||
|
|
||||||
if before_seq is not None:
|
if before_seq is not None:
|
||||||
# Records with seq < before_seq, then the last `limit` of them.
|
messages = [e for e in messages if e["seq"] < before_seq]
|
||||||
hi = bisect.bisect_left(messages, before_seq, key=lambda e: e["seq"])
|
# Take the last `limit` records
|
||||||
return messages[max(0, hi - limit) : hi]
|
return messages[-limit:]
|
||||||
elif after_seq is not None:
|
elif after_seq is not None:
|
||||||
# Records with seq > after_seq, then the first `limit` of them.
|
messages = [e for e in messages if e["seq"] > after_seq]
|
||||||
lo = bisect.bisect_right(messages, after_seq, key=lambda e: e["seq"])
|
return messages[:limit]
|
||||||
return messages[lo : lo + limit]
|
|
||||||
else:
|
else:
|
||||||
# Return the latest `limit` records, ascending.
|
# Return the latest `limit` records, ascending
|
||||||
return messages[-limit:]
|
return messages[-limit:]
|
||||||
|
|
||||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||||
@@ -119,11 +110,11 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(self, thread_id):
|
||||||
return len(self._messages.get(thread_id, []))
|
all_events = self._events.get(thread_id, [])
|
||||||
|
return sum(1 for e in all_events if e["category"] == "message")
|
||||||
|
|
||||||
async def delete_by_thread(self, thread_id):
|
async def delete_by_thread(self, thread_id):
|
||||||
events = self._events.pop(thread_id, [])
|
events = self._events.pop(thread_id, [])
|
||||||
self._messages.pop(thread_id, None)
|
|
||||||
self._seq_counters.pop(thread_id, None)
|
self._seq_counters.pop(thread_id, None)
|
||||||
return len(events)
|
return len(events)
|
||||||
|
|
||||||
@@ -134,6 +125,4 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
remaining = [e for e in all_events if e["run_id"] != run_id]
|
remaining = [e for e in all_events if e["run_id"] != run_id]
|
||||||
removed = len(all_events) - len(remaining)
|
removed = len(all_events) - len(remaining)
|
||||||
self._events[thread_id] = remaining
|
self._events[thread_id] = remaining
|
||||||
# Keep the message projection in lockstep (same surviving dict objects).
|
|
||||||
self._messages[thread_id] = [e for e in remaining if e["category"] == "message"]
|
|
||||||
return removed
|
return removed
|
||||||
|
|||||||
@@ -164,18 +164,7 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
metadata={"caller": caller, **(metadata or {})},
|
metadata={"caller": caller, **(metadata or {})},
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_end(
|
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self,
|
|
||||||
outputs: Any,
|
|
||||||
*,
|
|
||||||
run_id: UUID,
|
|
||||||
parent_run_id: UUID | None = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> None:
|
|
||||||
# Nested chain ends fire for internal graph nodes; only the root chain
|
|
||||||
# represents the user-visible run lifecycle.
|
|
||||||
if parent_run_id is not None:
|
|
||||||
return
|
|
||||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||||
self._flush_sync()
|
self._flush_sync()
|
||||||
|
|
||||||
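The longer `on_chain_end` signature in this diff accepts `parent_run_id` and returns early for nested chains, so only the root chain emits `run.end`. A minimal sketch of that guard, using a hypothetical `RootOnlyJournal` class that appends to a list instead of calling `_put` and `_flush_sync` (it does not subclass the real callback handler):

```python
from __future__ import annotations

from typing import Any
from uuid import UUID, uuid4


class RootOnlyJournal:
    """Hypothetical stand-in for a callback handler that records run.end."""

    def __init__(self) -> None:
        self.events: list[dict[str, Any]] = []

    def on_chain_end(
        self,
        outputs: Any,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> None:
        # A non-None parent means this is an inner node, not the root run.
        if parent_run_id is not None:
            return
        self.events.append(
            {"event_type": "run.end", "content": outputs, "metadata": {"status": "success"}}
        )


journal = RootOnlyJournal()
root_id = uuid4()
journal.on_chain_end({"node": "partial"}, run_id=uuid4(), parent_run_id=root_id)
journal.on_chain_end({"answer": "done"}, run_id=root_id)
```

Without the guard, as in the one-line signature in the other column, every nested chain end would record its own `run.end` event.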
|
|||||||
@@ -83,7 +83,6 @@ class RunRecord:
|
|||||||
multitask_strategy: str = "reject"
|
multitask_strategy: str = "reject"
|
||||||
metadata: dict = field(default_factory=dict)
|
metadata: dict = field(default_factory=dict)
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
user_id: str | None = None
|
|
||||||
created_at: str = ""
|
created_at: str = ""
|
||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
task: asyncio.Task | None = field(default=None, repr=False)
|
task: asyncio.Task | None = field(default=None, repr=False)
|
||||||
@@ -119,48 +118,13 @@ class RunManager:
|
|||||||
persistence_retry_policy: PersistenceRetryPolicy | None = None,
|
persistence_retry_policy: PersistenceRetryPolicy | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._runs: dict[str, RunRecord] = {}
|
self._runs: dict[str, RunRecord] = {}
|
||||||
# Secondary index: thread_id -> insertion-ordered run_id set (a dict is
|
|
||||||
# used as an ordered set), maintained in lockstep with ``_runs`` so
|
|
||||||
# per-thread queries avoid O(total in-memory runs) full scans while
|
|
||||||
# preserving ``_runs`` iteration order (see ``_thread_records_locked``).
|
|
||||||
self._runs_by_thread: dict[str, dict[str, None]] = {}
|
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
|
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
|
||||||
|
|
||||||
def _index_run_locked(self, record: RunRecord) -> None:
|
|
||||||
"""Register *record* in the thread index. Caller must hold ``self._lock``."""
|
|
||||||
self._runs_by_thread.setdefault(record.thread_id, {})[record.run_id] = None
|
|
||||||
|
|
||||||
def _unindex_run_locked(self, run_id: str, thread_id: str) -> None:
|
|
||||||
"""Drop *run_id* from the thread index. Caller must hold ``self._lock``."""
|
|
||||||
bucket = self._runs_by_thread.get(thread_id)
|
|
||||||
if bucket is not None:
|
|
||||||
bucket.pop(run_id, None)
|
|
||||||
if not bucket:
|
|
||||||
self._runs_by_thread.pop(thread_id, None)
|
|
||||||
|
|
||||||
def _thread_records_locked(self, thread_id: str) -> list[RunRecord]:
|
|
||||||
"""Return live in-memory records for *thread_id*. Caller must hold ``self._lock``.
|
|
||||||
|
|
||||||
Uses the ``_runs_by_thread`` index for O(runs-in-thread) lookup instead of
|
|
||||||
scanning every in-memory run. Correctness rests on the index and ``_runs``
|
|
||||||
being mutated in lockstep under ``self._lock`` (no ``await`` between the two
|
|
||||||
writes), so any holder of the lock sees them agree. The ``self._runs.get``
|
|
||||||
filter is defense-in-depth, not reconciliation: it drops a stale id still in
|
|
||||||
the index but already gone from ``_runs``, yet it cannot recover a run that is
|
|
||||||
in ``_runs`` but missing from the index (such a run would be silently
|
|
||||||
omitted). It guards only that one direction, should a future refactor ever
|
|
||||||
break the lockstep invariant.
|
|
||||||
"""
|
|
||||||
run_ids = self._runs_by_thread.get(thread_id)
|
|
||||||
if not run_ids:
|
|
||||||
return []
|
|
||||||
return [record for run_id in run_ids if (record := self._runs.get(run_id)) is not None]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
||||||
payload = {
|
return {
|
||||||
"thread_id": record.thread_id,
|
"thread_id": record.thread_id,
|
||||||
"assistant_id": record.assistant_id,
|
"assistant_id": record.assistant_id,
|
||||||
"status": record.status.value,
|
"status": record.status.value,
|
||||||
@@ -171,9 +135,6 @@ class RunManager:
|
|||||||
"created_at": record.created_at,
|
"created_at": record.created_at,
|
||||||
"model_name": record.model_name,
|
"model_name": record.model_name,
|
||||||
}
|
}
|
||||||
if record.user_id is not None:
|
|
||||||
payload["user_id"] = record.user_id
|
|
||||||
return payload
|
|
||||||
|
|
||||||
async def _call_store_with_retry(
|
async def _call_store_with_retry(
|
||||||
self,
|
self,
|
||||||
@@ -280,7 +241,6 @@ class RunManager:
|
|||||||
kwargs=row.get("kwargs") or {},
|
kwargs=row.get("kwargs") or {},
|
||||||
created_at=row.get("created_at") or "",
|
created_at=row.get("created_at") or "",
|
||||||
updated_at=row.get("updated_at") or "",
|
updated_at=row.get("updated_at") or "",
|
||||||
user_id=row.get("user_id"),
|
|
||||||
error=row.get("error"),
|
error=row.get("error"),
|
||||||
model_name=row.get("model_name"),
|
model_name=row.get("model_name"),
|
||||||
store_only=True,
|
store_only=True,
|
||||||
@@ -360,7 +320,6 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
user_id: str | None = None,
|
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Create a new pending run and register it."""
|
"""Create a new pending run and register it."""
|
||||||
run_id = str(uuid.uuid4())
|
run_id = str(uuid.uuid4())
|
||||||
@@ -374,13 +333,11 @@ class RunManager:
|
|||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
kwargs=kwargs or {},
|
kwargs=kwargs or {},
|
||||||
user_id=user_id,
|
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
self._index_run_locked(record)
|
|
||||||
persisted = False
|
persisted = False
|
||||||
try:
|
try:
|
||||||
await self._persist_new_run_to_store(record)
|
await self._persist_new_run_to_store(record)
|
||||||
@@ -392,7 +349,6 @@ class RunManager:
|
|||||||
# Also covers cancellation, which bypasses ``except Exception``.
|
# Also covers cancellation, which bypasses ``except Exception``.
|
||||||
if not persisted:
|
if not persisted:
|
||||||
self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
self._unindex_run_locked(run_id, record.thread_id)
|
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -448,7 +404,8 @@ class RunManager:
|
|||||||
limit: Maximum number of runs to return.
|
limit: Maximum number of runs to return.
|
||||||
"""
|
"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
memory_records = self._thread_records_locked(thread_id)
|
# Dict insertion order gives deterministic results when timestamps tie.
|
||||||
|
memory_records = [r for r in self._runs.values() if r.thread_id == thread_id]
|
||||||
if self._store is None:
|
if self._store is None:
|
||||||
return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
|
return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
|
||||||
records_by_id = {record.run_id: record for record in memory_records}
|
records_by_id = {record.run_id: record for record in memory_records}
|
||||||
@@ -547,7 +504,6 @@ class RunManager:
|
|||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
user_id: str | None = None,
|
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Atomically check for inflight runs and create a new one.
|
"""Atomically check for inflight runs and create a new one.
|
||||||
|
|
||||||
@@ -568,7 +524,7 @@ class RunManager:
|
|||||||
if multitask_strategy not in _supported_strategies:
|
if multitask_strategy not in _supported_strategies:
|
||||||
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
||||||
|
|
||||||
inflight = [r for r in self._thread_records_locked(thread_id) if r.status in (RunStatus.pending, RunStatus.running)]
|
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
|
||||||
|
|
||||||
if multitask_strategy == "reject" and inflight:
|
if multitask_strategy == "reject" and inflight:
|
||||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||||
@@ -590,13 +546,11 @@ class RunManager:
|
|||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
kwargs=kwargs or {},
|
kwargs=kwargs or {},
|
||||||
user_id=user_id,
|
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
self._index_run_locked(record)
|
|
||||||
persisted = False
|
persisted = False
|
||||||
try:
|
try:
|
||||||
await self._persist_new_run_to_store(record)
|
await self._persist_new_run_to_store(record)
|
||||||
@@ -608,7 +562,6 @@ class RunManager:
|
|||||||
# Also covers cancellation, which bypasses ``except Exception``.
|
# Also covers cancellation, which bypasses ``except Exception``.
|
||||||
if not persisted:
|
if not persisted:
|
||||||
self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
self._unindex_run_locked(run_id, record.thread_id)
|
|
||||||
|
|
||||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||||
for r in inflight:
|
for r in inflight:
|
||||||
@@ -682,16 +635,14 @@ class RunManager:
|
|||||||
async def has_inflight(self, thread_id: str) -> bool:
|
async def has_inflight(self, thread_id: str) -> bool:
|
||||||
"""Return ``True`` if *thread_id* has a pending or running run."""
|
"""Return ``True`` if *thread_id* has a pending or running run."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
return any(r.status in (RunStatus.pending, RunStatus.running) for r in self._thread_records_locked(thread_id))
|
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
|
||||||
|
|
||||||
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
||||||
"""Remove a run record after an optional delay."""
|
"""Remove a run record after an optional delay."""
|
||||||
if delay > 0:
|
if delay > 0:
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
record = self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
if record is not None:
|
|
||||||
self._unindex_run_locked(run_id, record.thread_id)
|
|
||||||
logger.debug("Run record %s cleaned up", run_id)
|
logger.debug("Run record %s cleaned up", run_id)
|
||||||
|
|
||||||
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||||
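The `_runs_by_thread` index in this diff uses a plain `dict[str, None]` per thread as an insertion-ordered set and is kept in lockstep with `_runs`. A minimal sketch of that pairing, assuming a hypothetical `RunIndex` class with a simplified `Record` type and no locking (the real manager mutates both structures under `self._lock`):

```python
from dataclasses import dataclass


@dataclass
class Record:
    run_id: str
    thread_id: str


class RunIndex:
    def __init__(self) -> None:
        self._runs: dict[str, Record] = {}
        # thread_id -> {run_id: None}; the inner dict acts as an ordered set.
        self._runs_by_thread: dict[str, dict[str, None]] = {}

    def add(self, record: Record) -> None:
        self._runs[record.run_id] = record
        self._runs_by_thread.setdefault(record.thread_id, {})[record.run_id] = None

    def remove(self, run_id: str) -> None:
        record = self._runs.pop(run_id, None)
        if record is None:
            return
        bucket = self._runs_by_thread.get(record.thread_id)
        if bucket is not None:
            bucket.pop(run_id, None)
            if not bucket:
                # Drop empty buckets so idle threads do not accumulate.
                self._runs_by_thread.pop(record.thread_id, None)

    def for_thread(self, thread_id: str) -> list[Record]:
        run_ids = self._runs_by_thread.get(thread_id)
        if not run_ids:
            return []
        # Skip ids that are stale in the index; this guards one direction only.
        return [r for rid in run_ids if (r := self._runs.get(rid)) is not None]


index = RunIndex()
for rid, tid in [("r1", "t1"), ("r2", "t2"), ("r3", "t1")]:
    index.add(Record(rid, tid))
index.remove("r2")
```

A per-thread lookup then costs O(runs in that thread) instead of a scan over every in-memory run, at the price of a second structure that every insert and removal must update.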
|
|||||||
@@ -56,56 +56,6 @@ def serialize_channel_values(channel_values: dict[str, Any]) -> dict[str, Any]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def strip_data_url_image_blocks(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
||||||
"""Remove ``data:``-scheme ``image_url`` blocks from *hide_from_ui* messages.
|
|
||||||
|
|
||||||
The history and run-wait endpoints return checkpoint-persisted messages to
|
|
||||||
the frontend. ``ViewImageMiddleware`` stores full base64 image payloads in
|
|
||||||
``hide_from_ui`` human messages — these are internal model context and must
|
|
||||||
not be sent over the wire (huge response bodies, no UI value).
|
|
||||||
|
|
||||||
Only content blocks of type ``image_url`` whose URL starts with ``data:``
|
|
||||||
are stripped. Text blocks, ``https://`` image URLs, and non-hidden
|
|
||||||
messages are left untouched so that message ordering and count are
|
|
||||||
preserved.
|
|
||||||
"""
|
|
||||||
result: list[dict[str, Any]] = []
|
|
||||||
for msg in messages:
|
|
||||||
if not isinstance(msg, dict):
|
|
||||||
result.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Only touch messages explicitly flagged as hidden from the UI.
|
|
||||||
additional_kwargs = msg.get("additional_kwargs")
|
|
||||||
if not (isinstance(additional_kwargs, dict) and additional_kwargs.get("hide_from_ui") is True):
|
|
||||||
result.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
content = msg.get("content")
|
|
||||||
if not isinstance(content, list):
|
|
||||||
result.append(msg)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Filter out image_url blocks with data: scheme.
|
|
||||||
filtered = [block for block in content if not (isinstance(block, dict) and block.get("type") == "image_url" and isinstance(block.get("image_url"), dict) and str(block["image_url"].get("url", "")).startswith("data:"))]
|
|
||||||
result.append({**msg, "content": filtered})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_channel_values_for_api(channel_values: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Serialize channel values and strip base64 image data from messages.
|
|
||||||
|
|
||||||
Convenience wrapper combining :func:`serialize_channel_values` with
|
|
||||||
:func:`strip_data_url_image_blocks`. Use this in all REST endpoints
|
|
||||||
that return channel values to the frontend so that ``data:``-scheme
|
|
||||||
base64 image payloads are never sent over the wire.
|
|
||||||
"""
|
|
||||||
result = serialize_channel_values(channel_values)
|
|
||||||
if isinstance(result.get("messages"), list):
|
|
||||||
result["messages"] = strip_data_url_image_blocks(result["messages"])
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_messages_tuple(obj: Any) -> Any:
|
def serialize_messages_tuple(obj: Any) -> Any:
|
||||||
"""Serialize a messages-mode tuple ``(chunk, metadata)``."""
|
"""Serialize a messages-mode tuple ``(chunk, metadata)``."""
|
||||||
if isinstance(obj, tuple) and len(obj) == 2:
|
if isinstance(obj, tuple) and len(obj) == 2:
|
||||||
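The docstring of `strip_data_url_image_blocks` in this diff describes three conditions for dropping a block: the message is flagged `hide_from_ui`, its content is a list, and the block is an `image_url` whose URL starts with `data:`. A minimal sketch of those rules on sample data, assuming hypothetical helper names (`_is_data_url_image`, `strip_hidden_data_images`) and made-up message payloads:

```python
from typing import Any


def _is_data_url_image(block: Any) -> bool:
    """True for an image_url content block that carries an inline data: URL."""
    return (
        isinstance(block, dict)
        and block.get("type") == "image_url"
        and isinstance(block.get("image_url"), dict)
        and str(block["image_url"].get("url", "")).startswith("data:")
    )


def strip_hidden_data_images(messages: list[Any]) -> list[Any]:
    result = []
    for msg in messages:
        kwargs = msg.get("additional_kwargs") if isinstance(msg, dict) else None
        hidden = isinstance(kwargs, dict) and kwargs.get("hide_from_ui") is True
        # Visible messages and non-list content pass through untouched.
        if not hidden or not isinstance(msg.get("content"), list):
            result.append(msg)
            continue
        kept = [b for b in msg["content"] if not _is_data_url_image(b)]
        # Copy the message so the caller's original dict is not mutated.
        result.append({**msg, "content": kept})
    return result


text_block = {"type": "text", "text": "see image"}
inline = {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
messages = [
    {"type": "human", "content": [text_block, inline], "additional_kwargs": {"hide_from_ui": True}},
    {"type": "human", "content": [inline]},
]
cleaned = strip_hidden_data_images(messages)
```

Message count and order are preserved; only the hidden message loses its inline image, and the visible message is returned as the same object.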
|
|||||||
Some files were not shown because too many files have changed in this diff.