Compare commits

...

53 Commits

Author SHA1 Message Date
taohe e0cb0df5da Persist disconnected IM channel state 2026-06-11 22:46:36 +08:00
DanielWalnut f3264ea5a9 Merge branch 'main' into codex/im-channel-connections 2026-06-11 22:32:36 +08:00
taohe b17bc5efc9 Ensure runtime IM channels are ready after restart 2026-06-11 22:19:03 +08:00
zgenu c733d3c917 fix(frontend): isolate new chat thread messages (#3508)
* fix(frontend): isolate new chat thread messages

* fix(frontend): keep live messages visible in new chat

* fix(frontend): reset thread-local message refs
2026-06-11 22:12:15 +08:00
Huixin615 b6fbf0d105 fix(frontend): keep workspace interactive when SSR auth probe cannot reach gateway (#3493) (#3495)
* fix(frontend): keep workspace interactive when SSR auth probe cannot reach gateway (#3493)

When the SSR auth probe at /api/v1/auth/me times out or fails, the
workspace layout used to render a static fallback page without
AuthProvider or QueryClientProvider, making logout and every other
interaction non-functional until the gateway recovered.

Render the normal WorkspaceContent in 'gateway_unavailable' mode
instead, surfacing a polite offline banner that re-probes the gateway
in the background and hides itself the moment refreshUser() returns
an authenticated user. The probe is reentrancy-guarded so a slow
gateway cannot pile up parallel /auth/me requests.

Closes #3493

* fix(workspace): silent probe in offline banner to avoid /login redirect during gateway recovery (#3493)

The banner previously delegated retry probes to AuthProvider.refreshUser(),
which treats any 401 from /api/v1/auth/me as 'session expired' and
force-redirects to /login. During gateway recovery, the first few requests
may transiently return 401 before the gateway is fully ready, which would
incorrectly kick the user out — defeating the purpose of the offline banner.

Now the banner silently fetches /api/v1/auth/me itself and only delegates
to refreshUser() on 200 OK. Non-200 responses (401 / 5xx / network) are
swallowed and retried on the next interval tick, ensuring the user stays
logged in across short gateway outages.

Verified in Docker:
- docker pause deer-flow-gateway → banner appears, page interactive
- docker unpause deer-flow-gateway → banner auto-disappears within 10s,
  user remains on /workspace/chats/new with full session restored
- All 117 unit tests pass

* fix(workspace): fix banner polling leak and persistent 401 handling (#3493)
- Stop polling immediately after user recovery: add user to effect dependencies, cleanup interval when user !== null
- Handle persistent 401: trigger login redirect after 3 consecutive unauthorized responses
- Extract decision logic to pure helper, add 8 unit tests covering all critical paths

* fix(workspace): address CR feedback on gateway offline recovery (#3493)

- gateway-offline-banner-helpers: decrement (not reset) auth-failure
  streak on transient outcomes so a flapping gateway (401 alternating
  with 5xx) still converges on session-expired
- gateway-offline-banner: reuse probe response body to apply user
  directly via new AuthProvider.applyUser, halving the recovery burst
  against an already-struggling gateway
- gateway-offline-banner: extract classifyProbe into helpers for unit
  testability; log probe failures via console.warn instead of swallowing
- gateway-offline-fallback: new shared component used by both workspace
  and (auth) layouts so auth pages recover the same way the workspace
  does, fixing the lockup where bare static HTML had no AuthProvider
- AuthProvider.logout: fall back to hard navigation when the gateway
  logout fetch fails, matching legacy form-POST behaviour and avoiding
  stale client state during outage
- tests: extend gateway-offline-banner-helpers.test with flapping
  convergence and classifyProbe branch coverage (19 cases total)
2026-06-11 21:14:49 +08:00
taohe 7ce047b723 Merge remote-tracking branch 'origin/main' into codex/im-channel-connections 2026-06-11 21:13:25 +08:00
taohe 0268ebbe6a Use sha256 user buckets with legacy migration 2026-06-11 21:08:08 +08:00
taohe 0798489734 Address channel connection review comments 2026-06-11 20:54:57 +08:00
taohe 5e3cc83d17 Isolate channel runtime config tests 2026-06-11 18:15:53 +08:00
taohe 6956b247fd Stabilize backend tests without local config 2026-06-11 18:07:06 +08:00
taohe b8323024c9 Fix frontend formatting after merge 2026-06-11 17:54:22 +08:00
DanielWalnut f401e7baa6 [codex] Fix stale AIO sandbox cache reuse (#3494)
* Fix stale AIO sandbox cache reuse

* Address AIO sandbox review feedback

* Distinguish sandbox health check failures

* Keep local discovery recoverable when the runtime check fails

LocalContainerBackend.discover() shares _is_container_running, which now
raises on transient daemon errors instead of returning False. Discovery has
no exception handling in _discover_or_create_with_lock(_async), so a brief
Docker hiccup turned a recoverable "could not verify, create instead" into a
hard acquire failure. Catch the check failure inside discover() and return
None so an unverifiable container is simply not adopted, restoring the
pre-change fall-through while keeping raise-on-unknown semantics protecting
the destroy path.

Reported by fancy-agent on PR #3494.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Narrow the not-found match in container inspect error handling

A bare "not found" substring also matches transient failures like "command
not found" or "context not found", which would misclassify a check error as
"container definitely gone" and bypass the raise-on-unknown contract. Keep
Docker's specific "No such object"/"No such container" phrases, and only
trust a generic "not found" (Apple Container) when the message names the
inspected container or refers to a container/object.

Reported by WillemJiang on PR #3494.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
2026-06-11 17:53:37 +08:00
taohe d1768606c0 Merge remote-tracking branch 'origin/main' into codex/im-channel-connections
# Conflicts:
#	backend/app/gateway/services.py
#	frontend/src/app/workspace/chats/page.tsx
2026-06-11 17:51:16 +08:00
Huixin615 919d8bc279 fix(sandbox): persist lazily-acquired sandbox state via Command (#3464)
* fix(sandbox): persist lazily-acquired sandbox state via Command

ensure_sandbox_initialized mutates runtime.state in place, which is local
to the current tool invocation and is not picked up by LangGraph's channel
reducer. Subsequent graph steps and downstream consumers (such as
ToolOutputBudgetMiddleware and the sub-agent task_tool) therefore cannot
observe the sandbox id from state.

Wrap tool calls in SandboxMiddleware (wrap_tool_call / awrap_tool_call) to
detect fresh lazy initialization by diffing runtime.state before and after
the handler, and emit a proper state update via Command(update=...):

- ToolMessage results are wrapped into Command(update={sandbox, messages})
- Command results with a dict update are merged on the sandbox key while
  preserving messages / goto / graph / resume
- Command results with non-dict updates are left untouched to avoid silent
  data loss on unknown update shapes

Tests:
- 7 new unit tests cover lazy-init emit, passthrough, dict-update merge,
  non-dict-update passthrough (sync and async)
- Refresh replay golden write_read_file.ultra.events.json: SSE 'values'
  events now correctly carry the 'sandbox' key in their keys list, which
  is the direct evidence that the fix is effective

Closes #3463

* refactor(sandbox): use dataclasses.replace to preserve Command fields

Address Copilot review on #3464: replace manual field-copy with
dataclasses.replace so any current or future Command fields are
preserved automatically when merging sandbox_update.

Also add a regression test that constructs a Command with non-None
graph/goto/resume to lock this behavior in.
2026-06-11 17:50:36 +08:00
taohe ddd1c5e42f Let setup wizard enable IM channels 2026-06-11 17:34:22 +08:00
taohe a270e8b310 Ignore Feishu non-content message events 2026-06-11 17:22:16 +08:00
taohe f330ddce01 Ignore Feishu message read events 2026-06-11 17:15:44 +08:00
taohe b26b30ac3d Reflect IM channel runtime health 2026-06-11 17:11:55 +08:00
taohe dae7c7870e Prefill IM channel runtime config 2026-06-11 17:02:02 +08:00
taohe 4f56437030 Show IM channel source on threads 2026-06-11 16:51:04 +08:00
taohe 42fd0cc22f Use default user for auth-disabled local mode 2026-06-11 16:33:37 +08:00
taohe a4202028d9 Route no-auth channel sessions to local user 2026-06-11 16:22:14 +08:00
taohe 4a0278420f Allow disconnecting runtime IM channels 2026-06-11 16:10:02 +08:00
taohe ade4a55cfe Persist IM runtime config locally 2026-06-11 15:58:40 +08:00
taohe 09872af36c Make channel threads visible to connection owners 2026-06-11 15:40:49 +08:00
taohe 92f562920d Avoid password autofill for channel secrets 2026-06-11 15:12:18 +08:00
taohe 9d51e38641 Keep configured IM channels editable 2026-06-11 14:40:37 +08:00
taohe c966eb71a7 Guard global shortcut key handling 2026-06-11 13:57:56 +08:00
taohe c4368c9018 Add runtime setup for enabled IM channels 2026-06-11 12:10:16 +08:00
taohe f83767bb17 Fix IM channel provider icons 2026-06-11 11:48:08 +08:00
taohe 0e939bfe23 Keep unavailable channel connect buttons clickable 2026-06-11 11:28:56 +08:00
taohe 89da9b70db Format additional channel connection tests 2026-06-11 11:20:39 +08:00
taohe a52deada8b Support all integrated IM channel connections 2026-06-11 11:19:27 +08:00
taohe b7097baaec Address IM channel review comments 2026-06-11 10:33:44 +08:00
taohe 87200ff920 Address Copilot IM channel feedback 2026-06-11 08:26:48 +08:00
Willem Jiang 2d5f0787de Update lint-check.yml with the job setting 2026-06-11 00:07:36 +08:00
Huixin615 5819bd8a59 fix(frontend): paginate workspace chat list beyond 50 threads (#3482) (#3485)
* fix(frontend): paginate workspace chat list beyond 50 threads (#3482)

The sidebar 'Recent chats' and /workspace/chats list were hard-capped
at the first 50 threads returned by threads.search. Replace the
single-shot useThreads() consumers with useInfiniteThreads() and add
an IntersectionObserver sentinel to each list so further pages are
fetched on demand.

In search mode on the chats page, the sentinel is replaced by an
explicit 'Load more' button to prevent the observer from draining the
entire backend list while the filtered view stays empty.

- Add useInfiniteThreads + page-size constant and pure cache helpers
  (map/filterInfiniteThreadsCache, getInfiniteThreadsNextPageParam)
- Mirror rename / delete / stream-finish updates into the new
  infinite cache so optimistic UI stays consistent
- Extend the e2e mock to honour limit/offset slicing
- Unit tests for the cache helpers and pagination boundary
- Playwright e2e covering chats page + sidebar load-more, and the
  search-mode guard against runaway auto-pagination
- Add en/zh i18n entries for the search-mode load-more button

Fixes #3482

* docs(frontend): clarify infinite-threads offset semantics and test post-delete invariant

- Add docstring to getInfiniteThreadsNextPageParam explaining that TanStack
  Query freezes the returned offset into pageParams once, so optimistic cache
  mutations that shrink page lengths (filterInfiniteThreadsCache on delete)
  cannot retroactively move the offset backwards. Delete/rename paths
  reconcile against the backend via invalidateQueries in onSettled.
- Add unit test covering the post-delete invariant.
- Fix misleading comment in thread-list-infinite-scroll.spec.ts: the
  thread-search mock does not sort by updated_at; it returns the array in
  the order provided.

Addresses Copilot CR comments on #3485.

* fix(frontend): mirror onCreated upsert into infinite cache; add sidebar Load-older button

Address review feedback on #3485:

- New upsertThreadInInfiniteCache helper; useThreadStream onCreated now
  upserts into both the legacy ['threads','search'] cache and the new
  infinite cache, so a freshly created thread appears in the sidebar
  immediately during streaming instead of only after the run finishes
  and onSettled invalidates the query. Restores parity with main.
- Sidebar Recent Chats now exposes a visible 'Load older chats' button
  alongside the IntersectionObserver sentinel, so keyboard-only users
  and environments where IO is unavailable can still reach older
  conversations.
- Add zh-CN / en-US / types entry for chats.loadOlderChats.
- Cover the new helper with 3 unit tests (no-op on uninitialised cache,
  prepend new thread to first page, merge with existing entry without
  duplication).
2026-06-10 23:59:38 +08:00
hataa b3c2cc42cf fix(agents): require config.yaml in resolve_agent_dir to skip memory-only directories (#3390) (#3481)
When memory is enabled, the first conversation with a legacy shared agent
creates a per-user agent directory containing only memory.json (no
config.yaml). On the second turn, resolve_agent_dir() returned this
incomplete directory, causing load_agent_config() to fail with
"Agent config not found".

Require config.yaml to exist alongside the directory for both the
per-user and legacy paths, so that memory-only directories fall
through correctly. This aligns resolve_agent_dir with the existing
config.yaml check in list_custom_agents.

Refs: https://github.com/bytedance/deer-flow/issues/3390
2026-06-10 23:57:17 +08:00
Ryker_Feng 167ef4512f feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429) (#3465)
* feat(memory): add memory.token_counting config to avoid tiktoken network dependency (#3429)

Add a `memory.token_counting` option (`tiktoken` | `char`) so deployments in
network-restricted environments can opt out of tiktoken entirely. In `char`
mode the memory-injection budget uses a network-free character-based estimate
and never triggers the BPE download from openaipublic.blob.core.windows.net,
which could otherwise block for tens of minutes (see #3402).

Also harden the default `tiktoken` path:
- cache an in-flight LOADING sentinel so concurrent callers fall back
  immediately instead of spawning more blocking get_encoding threads when the
  first load is still running (e.g. under the 5s startup warm-up timeout);
- cache failures with a timestamp and retry after a cooldown so a transient
  network outage self-heals back to accurate counting without a restart;
- skip startup warm-up entirely in char mode.

The new config is surfaced via the memory config API and config.example.yaml
(config_version bumped). Default remains `tiktoken`, so existing deployments
are unaffected.

* fix(memory): use CJK-aware char token estimate and address review feedback

- Replace the flat len(text)//4 fallback with a CJK-aware estimate so
  Chinese/Japanese/Korean memory content does not over-fill the injection budget
- Document the internal tiktoken retry cooldown and char-mode escape hatch
- Sync CLAUDE.md / config.example.yaml / MEMORY_IMPROVEMENTS.md wording
- Fix MemoryConfigResponse mocks/assertions and add CJK estimate tests
2026-06-10 23:26:15 +08:00
Xinmin Zeng ba9cc5e972 fix(gateway): enforce thread ownership on stateless run endpoints (#3473)
POST /api/runs/stream and /api/runs/wait accept thread_id in the request
body but performed no owner authorization, letting any authenticated user
start runs on -- and read /wait checkpoint channel_values from -- another
user's thread (cross-user IDOR, #3472).

The @require_permission(owner_check=True) decorator resolves ownership from
the thread_id *path* param, so it cannot cover these body-param endpoints.
Enforce ownership inside start_run() before create_or_reject via
ThreadMetaStore.check_access: missing rows (auto-created temp threads) and
NULL-owner rows stay accessible, while a thread owned by another user
returns 404 (matching thread_runs.py). The internal system role (IM
channels acting for platform users) is exempt.

Closes #3472
2026-06-10 23:03:39 +08:00
taohe 6a94b58ad1 Fix safe user id digest algorithm 2026-06-10 22:53:07 +08:00
taohe d06643d8a2 Align IM connections with local channels 2026-06-10 22:16:47 +08:00
taohe 92c185b90d Support local IM channel connections 2026-06-10 21:59:33 +08:00
taohe 9effa7be6d Merge remote-tracking branch 'origin/main' into codex/im-channel-connections 2026-06-10 21:42:12 +08:00
taohe 582bfda6f8 Harden dev service daemon startup 2026-06-10 21:41:40 +08:00
Xinmin Zeng 05ae4467ae fix(docker): default Gateway to a single worker to prevent multi-worker breakage (#3475)
The default `make up` started the Gateway with `--workers 4`, but run state
(RunManager and the stream bridge) is held in-process and nginx uses no sticky
sessions. With the default config, same-run requests scatter across workers that
each keep their own run state, breaking run cancellation (409), SSE reconnect
(hangs on heartbeats), multitask de-duplication, and IM channels (duplicate
replies). The shared cross-worker stream bridge does not exist yet.

Default GATEWAY_WORKERS to 1 so the out-of-the-box deployment is correct,
document the single-worker boundary in the README, and add a regression test
pinning the default while keeping it overridable. This is a stop-gap, not a
multi-worker implementation; the full fix (shared run state + stream bridge) is
tracked in #3191.

Refs #3239, #3260
2026-06-10 21:36:25 +08:00
taohe b66152c514 Use async channel connect flow 2026-06-10 21:34:29 +08:00
taohe 78fbc0abdb Fix dev startup and channel connect popup 2026-06-10 21:33:15 +08:00
taohe ec5ed185cd Merge remote-tracking branch 'origin/main' into codex/im-channel-connections
# Conflicts:
#	backend/app/channels/discord.py
#	backend/app/channels/manager.py
#	backend/app/channels/slack.py
#	backend/app/channels/telegram.py
2026-06-10 21:13:02 +08:00
taohe dbe3a3bb0d Add user-owned IM channel connections 2026-06-10 21:07:44 +08:00
DanielWalnut 2b795265e7 fix: align auth-disabled mode and mock history loading (#3471)
* fix: align auth-disabled mode and mock history loading

* fix: address auth-disabled review feedback

* test: cover auth-disabled backend contract

* style: format frontend tests

* fix: address follow-up review comments
2026-06-10 16:11:00 +08:00
Nan Gao a57d05fe0a fix runtime journal run lifecycle events (#3470) 2026-06-10 08:33:29 +08:00
Lucy Shen ae9e8bc0bf fix(sandbox): make missing sandbox.mounts host_path a loud ERROR (#3244) (#3250)
In Docker production deployments, LocalSandboxProvider runs inside the
deer-flow-gateway container, so any `sandbox.mounts[].host_path` from
config.yaml is resolved against the gateway container's filesystem — not
the host machine. When the path isn't also bind-mounted into the gateway
service, the mount was silently dropped with only a WARNING log, leaving
agents reading an empty directory in production while the same config
worked under `make dev`.

Escalate the missing-host_path branch to logger.error with explicit
guidance about Docker bind mounts and docker-compose, so the failure is
hard to miss in default log configurations. Skip behaviour is preserved
to avoid breaking existing deployments.

Also clarify the misleading `VolumeMountConfig.host_path` field
description so it documents reality for both providers:

  - LocalSandboxProvider checks host_path from inside the gateway process
    (host in `make dev`, container in `make up`).
  - AioSandboxProvider (DooD) passes host_path straight to `docker -v`
    for the sandbox container, where the host Docker daemon resolves it
    from the host machine's perspective.

config.example.yaml's `sandbox.mounts` comment gets a Note: block
pointing operators at the docker-compose bind-mount requirement so the
Docker-mode gotcha is discoverable from the canonical template.

Adds a regression test that:
  - confirms missing host_path is still skipped (no behaviour break);
  - asserts an ERROR record is emitted referencing the offending paths;
  - asserts the message contains actionable Docker/gateway/docker-compose
    keywords so future refactors can't quietly downgrade it.

Refs: https://github.com/bytedance/deer-flow/issues/3244
2026-06-09 23:16:14 +08:00
143 changed files with 11193 additions and 478 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ permissions:
contents: read
jobs:
lint:
lint-backend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
+5
View File
@@ -247,6 +247,9 @@ Access: http://localhost:2026
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
> [!IMPORTANT]
> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
#### Option 2: Local Development
@@ -340,6 +343,8 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
| Channel | Transport | Difficulty |
|---------|-----------|------------|
| Telegram | Bot API (long-polling) | Easy |
+30 -11
View File
@@ -284,7 +284,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
**Implementations**:
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation. Active-cache and warm-pool entries are checked with the backend during acquire/reuse; definitively dead containers are dropped from all in-process maps so the thread can discover or create a fresh sandbox instead of reusing a stale client. Backend health-check failures are treated as unknown, not dead; local discovery likewise treats an unverifiable container as not adoptable and falls through to create rather than failing acquire. `get()` remains an in-memory lookup for event-loop-safe tool paths.
**Virtual Path System**:
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
@@ -369,8 +369,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
### IM Channels System (`app/channels/`)
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
@@ -380,18 +379,21 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
**Message Flow**:
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
2. `ChannelManager._dispatch_loop()` consumes from queue
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
5. Slack/Telegram chat: `runs.wait()`extract final response → publish outbound
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
9. Outbound → channel callbacks → platform reply
3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
4. For chat: look up/create thread through Gateway's LangGraph-compatible API
5. Feishu chat: `runs.stream()`accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
6. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
8. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
9. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
10. Outbound → channel callbacks → platform reply
**Configuration** (`config.yaml` -> `channels`):
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
@@ -399,6 +401,16 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers.
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
- Slack replies use the configured operator bot token from `channels.slack` unless a future provider-token flow stores per-connection credentials.
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
### Memory System (`packages/harness/deerflow/agents/memory/`)
@@ -429,6 +441,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
**Configuration** (`config.yaml``memory`):
@@ -438,6 +456,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
- `model_name` - LLM for updates (null = default model)
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
- `max_injection_tokens` - Token limit for prompt injection (2000)
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
### Reflection System (`packages/harness/deerflow/reflection/`)
+1 -1
View File
@@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern:
Per-thread isolated execution with virtual path translation:
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. `AioSandboxProvider` validates active-cache and warm-pool containers during acquire/reuse, dropping definitively dead entries so a thread can provision a fresh sandbox after an unexpected container exit while keeping `get()` as an in-memory lookup. Backend health-check failures are treated as unknown, not dead, and a container that cannot be verified during discovery is simply not adopted (acquire falls through to create instead of failing).
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
- **Skills path**: `/mnt/skills``deer-flow/skills/` directory
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
+11
View File
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
)
def extract_connect_code(text: str) -> str | None:
"""Extract the one-time channel binding code from a connect command."""
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command."""
if not text.startswith("/"):
@@ -0,0 +1,44 @@
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
from __future__ import annotations
from typing import Any
from app.channels.message_bus import InboundMessage
async def attach_connection_identity(
inbound: InboundMessage,
*,
repo: Any,
provider: str,
workspace_id: str | None,
fallback_without_workspace: bool = False,
) -> InboundMessage:
"""Attach connection metadata to an inbound message when a persisted binding exists."""
if repo is None:
return inbound
workspace_candidates: list[str | None] = []
if workspace_id:
workspace_candidates.append(workspace_id)
if fallback_without_workspace:
workspace_candidates.append(None)
if not workspace_candidates:
return inbound
for candidate in workspace_candidates:
connection = await repo.find_connection_by_external_identity(
provider=provider,
external_account_id=inbound.user_id,
workspace_id=candidate,
)
if connection is None:
continue
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return inbound
+105 -1
View File
@@ -14,7 +14,8 @@ from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
self._incoming_messages: dict[str, Any] = {}
self._incoming_messages_lock = threading.Lock()
self._card_repliers: dict[str, Any] = {}
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
text[:100],
)
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
conversation_type=conversation_type,
sender_staff_id=sender_staff_id,
sender_nick=sender_nick,
conversation_id=conversation_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
return
if _is_dingtalk_command(text):
msg_type = InboundMessageType.COMMAND
else:
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
return ""
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
inbound = await self._attach_connection_identity(inbound)
# Running reply must finish before publish_inbound so AI card tracks are
# registered before the manager emits streaming outbounds.
await self._send_running_reply(chat_id, inbound)
await self.bus.publish_inbound(inbound)
@staticmethod
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
return conversation_id
return None
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
conversation_id = str(inbound.metadata.get("conversation_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="dingtalk",
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(
self,
*,
conversation_type: str,
sender_staff_id: str,
sender_nick: str,
conversation_id: str,
code: str,
) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
if state is None:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection code is invalid or expired.",
)
return True
if not sender_staff_id:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection could not be completed from this message.",
)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="dingtalk",
external_account_id=sender_staff_id,
external_account_name=sender_nick or None,
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
metadata={
"conversation_type": conversation_type,
"conversation_id": conversation_id,
},
status="connected",
)
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connected to DeerFlow.",
)
return True
async def _send_connection_reply(
self,
conversation_type: str,
sender_staff_id: str,
conversation_id: str,
text: str,
) -> None:
robot_code = self._client_id
if conversation_type == _CONVERSATION_TYPE_GROUP:
if conversation_id:
await self._send_text_message_to_group(robot_code, conversation_id, text)
return
if sender_staff_id:
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
+64 -2
View File
@@ -10,8 +10,9 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -287,6 +289,10 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
# --- Determine thread/channel routing and typing target ---
thread_id = None
chat_id = None
@@ -315,6 +321,7 @@ class DiscordChannel(Channel):
},
)
inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
self._publish(inbound)
# Start typing indicator in the thread
if typing_target:
@@ -422,6 +429,7 @@ class DiscordChannel(Channel):
},
)
inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
# Start typing indicator in the correct target (thread or channel)
if typing_target:
@@ -436,6 +444,60 @@ class DiscordChannel(Channel):
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="discord",
workspace_id=guild_id,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
if state is None:
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
return True
guild = getattr(message, "guild", None)
channel = getattr(message, "channel", None)
author = getattr(message, "author", None)
user_id = str(getattr(author, "id", "") or "")
if not user_id:
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
return True
guild_id = str(getattr(guild, "id", "") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="discord",
external_account_id=user_id,
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
workspace_id=guild_id,
workspace_name=getattr(guild, "name", None) if guild is not None else None,
metadata={
"guild_id": guild_id,
"channel_id": str(getattr(channel, "id", "") or ""),
},
status="connected",
)
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
return True
@staticmethod
async def _send_connection_reply(message, text: str) -> None:
channel = getattr(message, "channel", None)
send = getattr(channel, "send", None)
if send is None:
return
try:
await send(text)
except Exception:
logger.exception("[Discord] failed to send connection reply")
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
+78 -2
View File
@@ -11,7 +11,8 @@ import time
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
self._CreateImageRequestBody = None
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
self._connection_repo = config.get("connection_repo")
@staticmethod
def _non_empty_str(value: Any) -> str | None:
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
def supports_streaming(self) -> bool:
return True
@property
def is_running(self) -> bool:
if not self._running:
return False
return self._thread is not None and self._thread.is_alive()
def _build_event_handler(self, lark):
return (
lark.EventDispatcherHandler.builder("", "")
.register_p2_im_message_receive_v1(self._on_message)
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
.build()
)
async def start(self) -> None:
if self._running:
return
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
# thread's uvloop.
_ws_client_mod.loop = loop
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
event_handler = self._build_event_handler(lark)
ws_client = lark.ws.Client(
app_id=app_id,
app_secret=app_secret,
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
except Exception:
if self._running:
logger.exception("Feishu WebSocket error")
self._running = False
def _on_ignored_message_event(self, event) -> None:
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
async def stop(self) -> None:
self._running = False
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
"""Kick off Feishu side effects without delaying inbound dispatch."""
inbound = await self._attach_connection_identity(inbound)
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
self._ensure_running_card_started(msg_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="feishu",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
if state is None:
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
return True
if not user_id or not chat_id:
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="feishu",
external_account_id=user_id,
workspace_id=chat_id,
metadata={
"chat_id": chat_id,
"message_id": message_id,
},
status="connected",
)
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
return True
def _on_message(self, event) -> None:
"""Called by lark-oapi when a message is received (runs in lark thread)."""
try:
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
logger.info("[Feishu] empty text, ignoring message")
return
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
message_id=msg_id,
chat_id=chat_id,
user_id=sender_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
return
# Only treat known slash commands as commands; absolute paths and
# other slash-prefixed text should be handled as normal chat.
if _is_feishu_command(text):
+142 -30
View File
@@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
return metadata
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
channel_source: dict[str, Any] = {
"type": "im_channel",
"provider": msg.channel_name,
"chat_id": msg.chat_id,
}
if msg.topic_id:
channel_source["topic_id"] = msg.topic_id
if msg.thread_ts:
channel_source["thread_ts"] = msg.thread_ts
if msg.connection_id:
channel_source["connection_id"] = msg.connection_id
return {"channel_source": channel_source}
def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field."""
if isinstance(content, str):
@@ -440,6 +456,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
return message
def _auth_disabled_owner_user_id() -> str | None:
try:
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
except Exception:
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
return None
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
return _auth_disabled_owner_user_id() or msg.owner_user_id
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
msg.owner_user_id = owner_user_id
return msg
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
owner_user_id = _effective_owner_user_id(msg)
if not owner_user_id:
return None
return create_internal_auth_headers(owner_user_id=owner_user_id)
def _safe_user_id_for_run(raw_user_id: str) -> str:
from deerflow.config.paths import get_paths
try:
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
except Exception:
logger.exception("Failed to prepare channel run user directory")
return make_safe_user_id(raw_user_id)
def _resolve_slash_skill_command(
text: str,
available_skills: set[str] | None = None,
@@ -670,6 +723,7 @@ class ChannelManager:
assistant_id: str = DEFAULT_ASSISTANT_ID,
default_session: dict[str, Any] | None = None,
channel_sessions: dict[str, Any] | None = None,
connection_repo: Any | None = None,
) -> None:
self.bus = bus
self.store = store
@@ -679,6 +733,7 @@ class ChannelManager:
self._assistant_id = assistant_id
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._connection_repo = connection_repo
self._client = None # lazy init — langgraph_sdk async client
self._skill_storage: SkillStorage | None = None
self._csrf_token = generate_csrf_token()
@@ -728,12 +783,17 @@ class ChannelManager:
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
# ``user_id`` drives user-scoped filesystem buckets that only accept
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
# under ``channel_user_id`` for platform-facing lookups.
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
# For browser-connected IM channels, prefer the DeerFlow account that
# owns the connection. Preserve the raw platform user under
# ``channel_user_id`` for platform-facing lookups and audits.
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
elif msg.user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id
run_context = _merge_dicts(
@@ -845,6 +905,7 @@ class ChannelManager:
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
async def _handle_message(self, msg: InboundMessage) -> None:
msg = _apply_effective_owner(msg)
async with self._semaphore:
try:
if msg.msg_type == InboundMessageType.COMMAND:
@@ -877,10 +938,27 @@ class ChannelManager:
# -- chat handling -----------------------------------------------------
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
thread = await client.threads.create()
thread_id = thread["thread_id"]
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
if msg.connection_id and self._connection_repo is not None:
return await self._connection_repo.get_thread_id(
msg.connection_id,
msg.chat_id,
msg.topic_id,
)
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
await self._connection_repo.set_thread_id(
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
provider=msg.channel_name,
external_conversation_id=msg.chat_id,
external_topic_id=msg.topic_id,
thread_id=thread_id,
)
return
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
@@ -888,18 +966,40 @@ class ChannelManager:
topic_id=msg.topic_id,
user_id=msg.user_id,
)
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
metadata = _thread_channel_metadata(msg)
owner_headers = _owner_headers(msg)
if owner_headers:
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
else:
thread = await client.threads.create(metadata=metadata)
thread_id = thread["thread_id"]
await self._store_thread_id(msg, thread_id)
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
return thread_id
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
"""Best-effort source metadata backfill for existing IM-created threads."""
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
if owner_headers := _owner_headers(msg):
update_kwargs["headers"] = owner_headers
try:
await client.threads.update(thread_id, **update_kwargs)
except Exception:
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
client = self._get_client()
# Look up existing DeerFlow thread.
# topic_id may be None (e.g. Telegram private chats) — the store
# handles this by using the "channel:chat_id" key without a topic suffix.
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
thread_id = await self._lookup_thread_id(msg)
if thread_id:
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
await self._update_thread_channel_metadata(client, msg, thread_id)
# No existing thread found — create a new one
if thread_id is None:
@@ -940,14 +1040,19 @@ class ChannelManager:
return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
run_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
run_kwargs["headers"] = owner_headers
try:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [human_message]},
config=run_config,
context=run_context,
multitask_strategy="reject",
**run_kwargs,
)
except Exception as exc:
if _is_thread_busy_error(exc):
@@ -984,6 +1089,8 @@ class ChannelManager:
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
@@ -1008,16 +1115,21 @@ class ChannelManager:
last_published_text = ""
last_publish_at = 0.0
stream_error: BaseException | None = None
stream_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"stream_mode": ["messages-tuple", "values"],
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
stream_kwargs["headers"] = owner_headers
try:
async for chunk in client.runs.stream(
thread_id,
assistant_id,
input={"messages": [human_message]},
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
multitask_strategy="reject",
**stream_kwargs,
):
event = getattr(chunk, "event", "")
data = getattr(chunk, "data", None)
@@ -1047,6 +1159,8 @@ class ChannelManager:
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata),
)
)
@@ -1093,6 +1207,8 @@ class ChannelManager:
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
)
)
@@ -1124,18 +1240,10 @@ class ChannelManager:
if reply is None and command == "new":
# Create a new thread through Gateway
client = self._get_client()
thread = await client.threads.create()
new_thread_id = thread["thread_id"]
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
new_thread_id,
topic_id=msg.topic_id,
user_id=msg.user_id,
)
await self._create_thread(client, msg)
reply = "New conversation started."
elif reply is None and command == "status":
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
thread_id = await self._lookup_thread_id(msg)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif reply is None and command == "models":
reply = await self._fetch_gateway("/api/models", "models")
@@ -1174,9 +1282,11 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=await self._lookup_thread_id(msg) or "",
text=reply,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
@@ -1212,9 +1322,11 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=await self._lookup_thread_id(msg) or "",
text=error_text,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
+14
View File
@@ -44,6 +44,12 @@ class InboundMessage:
Messages sharing the same ``topic_id`` within a ``chat_id`` will
reuse the same DeerFlow thread. When ``None``, each message
creates a new thread (one-shot Q&A).
connection_id: Optional DeerFlow channel connection id. When present,
conversation mapping is scoped by the connection instead of the
legacy global ``channel_name:chat_id[:topic_id]`` key.
owner_user_id: DeerFlow user id that owns the channel connection.
Platform user ids stay in ``user_id``.
workspace_id: Optional external workspace/guild/team id.
files: Optional list of file attachments (platform-specific dicts).
metadata: Arbitrary extra data from the channel.
created_at: Unix timestamp when the message was created.
@@ -56,6 +62,9 @@ class InboundMessage:
msg_type: InboundMessageType = InboundMessageType.CHAT
thread_ts: str | None = None
topic_id: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
workspace_id: str | None = None
files: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@@ -95,6 +104,9 @@ class OutboundMessage:
is_final: Whether this is the final message in the response stream.
thread_ts: Optional platform thread identifier for threaded replies.
metadata: Arbitrary extra data.
connection_id: Optional DeerFlow channel connection id used for
connection-specific outbound credentials.
owner_user_id: DeerFlow user id that owns the channel connection.
created_at: Unix timestamp.
"""
@@ -106,6 +118,8 @@ class OutboundMessage:
attachments: list[ResolvedAttachment] = field(default_factory=list)
is_final: bool = True
thread_ts: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@@ -0,0 +1,154 @@
"""Local persistence for runtime IM channel configuration."""
from __future__ import annotations
import json
import logging
import tempfile
import threading
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
class ChannelRuntimeConfigStore:
"""JSON-backed store for channel credentials entered from the UI.
This intentionally mirrors ``ChannelStore``: local/private deployments get
durable runtime configuration without needing a public callback URL or a
config.yaml edit.
"""
def __init__(self, path: str | Path | None = None) -> None:
if path is None:
from deerflow.config.paths import get_paths
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
self._path = Path(path)
self._path.parent.mkdir(parents=True, exist_ok=True)
self._data: dict[str, dict[str, Any]] = self._load()
self._lock = threading.Lock()
def _load(self) -> dict[str, dict[str, Any]]:
if self._path.exists():
try:
raw = json.loads(self._path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
return {}
if isinstance(raw, dict):
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
return {}
def _save(self) -> None:
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=self._path.parent,
suffix=".tmp",
delete=False,
)
try:
json.dump(self._data, fd, indent=2, ensure_ascii=False)
fd.close()
Path(fd.name).replace(self._path)
try:
self._path.chmod(0o600)
except OSError:
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def load_all(self) -> dict[str, dict[str, Any]]:
with self._lock:
return {name: dict(config) for name, config in self._data.items()}
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
with self._lock:
config = self._data.get(provider)
return dict(config) if isinstance(config, dict) else None
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
with self._lock:
self._data[provider] = dict(config)
self._save()
def set_provider_disconnected(self, provider: str) -> None:
with self._lock:
self._data[provider] = {
"enabled": False,
RUNTIME_CHANNEL_DISABLED_FLAG: True,
}
self._save()
def remove_provider_config(self, provider: str) -> bool:
with self._lock:
if provider not in self._data:
return False
del self._data[provider]
self._save()
return True
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
provider_config = getattr(channel_connections_config, provider, None)
return bool(getattr(provider_config, "enabled", False))
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
def merge_runtime_channel_configs(
channels_config: dict[str, Any],
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> None:
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return
runtime_store = store or ChannelRuntimeConfigStore()
for provider, runtime_config in runtime_store.load_all().items():
if not _provider_enabled(channel_connections_config, provider):
continue
if _runtime_channel_disconnected(runtime_config):
channels_config.pop(provider, None)
continue
existing = channels_config.get(provider)
merged = dict(runtime_config)
if isinstance(existing, dict):
merged.update(existing)
channels_config[provider] = merged
def apply_runtime_connection_config(
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> Any:
"""Apply persisted connection metadata that lives outside ``channels``.
Telegram uses a bot username for deep links; UI-entered values are stored
with the runtime channel config so local restarts keep the provider
configured.
"""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return channel_connections_config
runtime_store = store or ChannelRuntimeConfigStore()
telegram_runtime_config = runtime_store.get_provider_config("telegram")
bot_username = ""
if isinstance(telegram_runtime_config, dict):
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
return channel_connections_config
config = channel_connections_config.model_copy(deep=True)
config.telegram.bot_username = bot_username
return config
+121 -22
View File
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.runtime_config_store import merge_runtime_channel_configs
from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
@@ -42,6 +43,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
value = config.pop(config_key, None)
if isinstance(value, str) and value.strip():
@@ -52,6 +58,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
return default
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
connection_config = getattr(app_config, "channel_connections", None)
merge_runtime_channel_configs(channels_config, connection_config)
def _make_connection_repo(app_config: AppConfig):
connection_config = getattr(app_config, "channel_connections", None)
if connection_config is None or not getattr(connection_config, "enabled", False):
return None
try:
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
except Exception:
logger.exception("Failed to import channel connection repository")
return None
session_factory = get_session_factory()
if session_factory is None:
logger.warning("Channel connections are enabled but database persistence is not available")
return None
return ChannelConnectionRepository(session_factory)
class ChannelService:
"""Manages the lifecycle of all configured IM channels.
@@ -59,9 +89,10 @@ class ChannelService:
instantiates enabled channels, and starts the ChannelManager dispatcher.
"""
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
self.bus = MessageBus()
self.store = ChannelStore()
self._connection_repo = connection_repo
config = dict(channels_config or {})
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
@@ -74,6 +105,7 @@ class ChannelService:
gateway_url=gateway_url,
default_session=default_session if isinstance(default_session, dict) else None,
channel_sessions=channel_sessions,
connection_repo=connection_repo,
)
self._channels: dict[str, Any] = {} # name -> Channel instance
self._config = config
@@ -90,8 +122,9 @@ class ChannelService:
# extra fields are allowed by AppConfig (extra="allow")
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
channels_config = dict(extra["channels"] or {})
_merge_channel_connection_runtime_config(channels_config, app_config)
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
async def start(self) -> None:
"""Start the manager and all enabled channels."""
@@ -99,36 +132,79 @@ class ChannelService:
return
await self.manager.start()
self._running = True
ready_status = await self.ensure_ready_channels(attempts=2)
ready_count = sum(1 for ready in ready_status.values() if ready)
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
"""Start or restart enabled configured channels that are not ready."""
ready_status: dict[str, bool] = {}
for name, channel_config in self._config.items():
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
if _channel_has_credentials(name, channel_config):
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
)
else:
logger.info("Channel %s is disabled, skipping", name)
logger.info("A configured channel is disabled, skipping")
continue
await self._start_channel(name, channel_config)
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
return ready_status
self._running = True
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
async def ensure_channel_ready(
self,
name: str,
config: dict[str, Any] | None = None,
*,
attempts: int = 1,
) -> bool:
"""Ensure a single enabled channel is running using its current config."""
if not self._running:
logger.warning("ChannelService is not running; cannot ensure channel readiness")
return False
if config is not None:
self._config[name] = dict(config)
channel_config = self._config.get(name)
if not channel_config or not isinstance(channel_config, dict):
logger.warning("No config for requested channel")
return False
if not channel_config.get("enabled", False):
return False
channel = self._channels.get(name)
if channel is not None and channel.is_running:
return True
if channel is not None:
try:
await channel.stop()
except Exception:
logger.exception("Error stopping non-running channel before readiness retry")
self._channels.pop(name, None)
max_attempts = max(1, attempts)
for attempt in range(max_attempts):
if attempt > 0:
logger.info("Retrying channel startup after readiness check")
if await self._start_channel(name, channel_config):
return True
return False
async def stop(self) -> None:
"""Stop all channels and the manager."""
for name, channel in list(self._channels.items()):
try:
await channel.stop()
logger.info("Channel %s stopped", name)
logger.info("Channel stopped")
except Exception:
logger.exception("Error stopping channel %s", name)
logger.exception("Error stopping channel")
self._channels.clear()
await self.manager.stop()
@@ -141,21 +217,42 @@ class ChannelService:
try:
await self._channels[name].stop()
except Exception:
logger.exception("Error stopping channel %s for restart", name)
logger.exception("Error stopping channel for restart")
del self._channels[name]
config = self._config.get(name)
if not config or not isinstance(config, dict):
logger.warning("No config for channel %s", name)
logger.warning("No config for requested channel")
return False
return await self._start_channel(name, config)
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Apply runtime config for a channel and restart it if the service is running."""
self._config[name] = dict(config)
if not self._running:
return True
return await self.restart_channel(name)
async def remove_channel(self, name: str) -> bool:
"""Remove runtime config for a channel and stop it if currently running."""
self._config.pop(name, None)
channel = self._channels.pop(name, None)
if channel is None:
return True
try:
await channel.stop()
logger.info("Channel stopped and removed")
return True
except Exception:
logger.exception("Error stopping channel for removal")
return False
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Instantiate and start a single channel."""
import_path = _CHANNEL_REGISTRY.get(name)
if not import_path:
logger.warning("Unknown channel type: %s", name)
logger.warning("Unknown channel type")
return False
try:
@@ -163,24 +260,26 @@ class ChannelService:
channel_cls = resolve_class(import_path, base_class=None)
except Exception:
logger.exception("Failed to import channel class for %s", name)
logger.exception("Failed to import channel class")
return False
try:
config = dict(config)
config["channel_store"] = self.store
if self._connection_repo is not None:
config["connection_repo"] = self._connection_repo
channel = channel_cls(bus=self.bus, config=config)
self._channels[name] = channel
await channel.start()
if not channel.is_running:
self._channels.pop(name, None)
logger.error("Channel %s did not enter a running state after start()", name)
logger.error("Channel did not enter a running state after start()")
return False
logger.info("Channel %s started", name)
logger.info("Channel started")
return True
except Exception:
self._channels.pop(name, None)
logger.exception("Failed to start channel %s", name)
logger.exception("Failed to start channel")
return False
def get_status(self) -> dict[str, Any]:
+139 -25
View File
@@ -9,7 +9,8 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -64,6 +65,8 @@ class SlackChannel(Channel):
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._connection_repo = config.get("connection_repo")
self._web_client_factory = config.get("web_client_factory")
configured_bot_user_id = config.get("bot_user_id")
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
@@ -80,26 +83,28 @@ class SlackChannel(Channel):
return
self._SocketModeResponse = SocketModeResponse
if self._web_client_factory is None:
self._web_client_factory = WebClient
bot_token = self.config.get("bot_token", "")
app_token = self.config.get("app_token", "")
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
if not bot_token:
logger.error("Slack HTTP Events mode requires bot_token")
return
await self._initialize_operator_web_client(str(bot_token))
self._loop = asyncio.get_event_loop()
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
logger.info("Slack channel started in HTTP Events mode")
return
if not bot_token or not app_token:
logger.error("Slack channel requires bot_token and app_token")
return
self._web_client = WebClient(token=bot_token)
if self._bot_user_id is None:
try:
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
await self._initialize_operator_web_client(str(bot_token))
self._socket_client = SocketModeClient(
app_token=app_token,
web_client=self._web_client,
@@ -124,7 +129,8 @@ class SlackChannel(Channel):
logger.info("Slack channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._web_client:
web_client = await self._get_web_client_for_message(msg)
if not web_client:
return
kwargs: dict[str, Any] = {
@@ -137,11 +143,12 @@ class SlackChannel(Channel):
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
# Add a completion reaction to the thread root
if msg.thread_ts:
await asyncio.to_thread(
self._add_reaction,
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"white_check_mark",
@@ -165,7 +172,8 @@ class SlackChannel(Channel):
if msg.thread_ts:
try:
await asyncio.to_thread(
self._add_reaction,
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"x",
@@ -177,7 +185,8 @@ class SlackChannel(Channel):
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._web_client:
web_client = await self._get_web_client_for_message(msg)
if not web_client:
return False
try:
@@ -190,7 +199,7 @@ class SlackChannel(Channel):
if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
return True
except Exception:
@@ -199,12 +208,38 @@ class SlackChannel(Channel):
# -- internal ----------------------------------------------------------
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
async def _initialize_operator_web_client(self, bot_token: str) -> None:
self._web_client = self._web_client_factory(token=bot_token)
if self._bot_user_id is not None:
return
try:
self._web_client.reactions_add(
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
async def _get_web_client_for_message(self, msg: OutboundMessage):
if msg.connection_id and self._connection_repo is not None:
credentials = await self._connection_repo.get_credentials(msg.connection_id)
access_token = credentials.get("access_token") if credentials else None
if not access_token:
return self._web_client
if self._web_client_factory is None:
from slack_sdk import WebClient
self._web_client_factory = WebClient
return self._web_client_factory(token=access_token)
return self._web_client
@staticmethod
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
try:
web_client.reactions_add(
channel=channel_id,
timestamp=timestamp,
name=emoji,
@@ -213,6 +248,12 @@ class SlackChannel(Channel):
if "already_reacted" not in str(exc):
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
return
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
if not self._web_client:
@@ -249,12 +290,15 @@ class SlackChannel(Channel):
# Handle message events (DM or @mention)
if etype in ("message", "app_mention"):
self._handle_message_event(event)
self._handle_message_event(
event,
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
)
except Exception:
logger.exception("Error processing Slack event")
def _handle_message_event(self, event: dict) -> None:
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
# Ignore bot messages
if event.get("bot_id") or event.get("subtype"):
return
@@ -272,6 +316,19 @@ class SlackChannel(Channel):
if not text:
return
connect_code = extract_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
event=event,
team_id=str(team_id or event.get("team") or ""),
code=connect_code,
),
self._loop,
)
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
@@ -297,4 +354,61 @@ class SlackChannel(Channel):
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
# Send "running" reply first (fire-and-forget from SDK thread)
self._send_running_reply(channel_id, thread_ts)
if self._connection_repo is None:
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
else:
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="slack",
workspace_id=workspace_id,
)
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
channel_id = str(event.get("channel") or "")
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
if state is None:
self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
return True
user_id = str(event.get("user") or "")
if not user_id or not team_id:
self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="slack",
external_account_id=user_id,
workspace_id=team_id,
metadata={
"team_id": team_id,
"channel_id": channel_id,
},
status="connected",
)
self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
return True
def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
if not self._web_client or not channel_id:
return
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
if thread_ts:
kwargs["thread_ts"] = thread_ts
try:
self._web_client.chat_postMessage(**kwargs)
except Exception:
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
+77
View File
@@ -8,6 +8,7 @@ import threading
from typing import Any
from app.channels.base import Channel
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ class TelegramChannel(Channel):
pass
# chat_id -> last sent message_id for threaded replies
self._last_bot_message: dict[str, int] = {}
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -176,6 +178,26 @@ class TelegramChannel(Channel):
logger.exception("[Telegram] failed to send file: %s", attachment.filename)
return False
async def process_webhook_update(self, payload: dict[str, Any]) -> bool:
if not self._application:
return False
try:
from telegram import Update
except ImportError:
logger.error("python-telegram-bot is not installed. Install it with: uv add python-telegram-bot")
return False
update = Update.de_json(payload, self._application.bot)
if update is None:
return False
if self._tg_loop and self._tg_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self._application.process_update(update), self._tg_loop)
await asyncio.wrap_future(future)
else:
await self._application.process_update(update)
return True
# -- helpers -----------------------------------------------------------
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
@@ -233,6 +255,54 @@ class TelegramChannel(Channel):
return True
return user_id in self._allowed_users
@staticmethod
def _telegram_display_name(user) -> str:
full_name = getattr(user, "full_name", None)
if isinstance(full_name, str) and full_name:
return full_name
username = getattr(user, "username", None)
if isinstance(username, str) and username:
return username
return str(getattr(user, "id", ""))
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
if self._connection_repo is None or not state_token:
return False
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
if state is None:
await update.message.reply_text("Telegram connection link is invalid or expired.")
return True
owner_user_id = state["owner_user_id"]
user_id = str(update.effective_user.id)
chat_id = str(update.effective_chat.id)
connection = await self._connection_repo.upsert_connection(
owner_user_id=owner_user_id,
provider="telegram",
external_account_id=user_id,
external_account_name=self._telegram_display_name(update.effective_user),
workspace_id=chat_id,
workspace_name=None,
metadata={
"chat_id": chat_id,
"chat_type": update.effective_chat.type,
"telegram_username": getattr(update.effective_user, "username", None),
},
status="connected",
)
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
await update.message.reply_text("Telegram connected to DeerFlow.")
return True
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="telegram",
workspace_id=inbound.chat_id,
)
def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None)
username = getattr(bot, "username", None)
@@ -264,6 +334,11 @@ class TelegramChannel(Channel):
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
return
args = getattr(context, "args", []) if context is not None else []
if args:
handled = await self._bind_connection_from_start_token(update, str(args[0]))
if handled:
return
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
@@ -299,6 +374,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id,
)
inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
@@ -341,6 +417,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id,
)
inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
+60 -2
View File
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
self._connection_repo = config.get("connection_repo")
self._load_state()
async def start(self) -> None:
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
if thread_ts:
self._context_tokens_by_thread[thread_ts] = context_token
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
chat_id=chat_id,
context_token=context_token,
code=connect_code,
)
if handled:
return
inbound = self._make_inbound(
chat_id=chat_id,
user_id=chat_id,
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
},
)
inbound.topic_id = None
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wechat",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
if state is None:
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
return True
if not chat_id:
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wechat",
external_account_id=chat_id,
workspace_id=chat_id,
metadata={
"context_token": context_token,
},
status="connected",
)
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
return True
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
if not context_token:
return
await self._send_text_message(
chat_id=chat_id,
context_token=context_token,
text=text,
client_id_prefix="deerflow-connect",
max_retries=1,
)
async def _ensure_authenticated(self) -> bool:
async with self._auth_lock:
if self._bot_token:
+58 -1
View File
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid")
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
frame=frame,
user_id=str(user_id or ""),
code=connect_code,
)
if handled:
return
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
except Exception:
pass
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wecom",
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
if state is None:
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
return True
if not user_id:
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
return True
body = frame.get("body", {}) or {}
workspace_id = str(body.get("aibotid") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wecom",
external_account_id=user_id,
workspace_id=workspace_id,
metadata={
"aibotid": workspace_id,
"chattype": body.get("chattype"),
},
status="connected",
)
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
return True
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
if not self._ws_client:
return
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._ws_client:
return
+14 -2
View File
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
@@ -15,6 +16,7 @@ from app.gateway.routers import (
artifacts,
assistants_compat,
auth,
channel_connections,
channels,
feedback,
mcp,
@@ -172,6 +174,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
startup_config = get_app_config()
apply_logging_level(startup_config.log_level)
logger.info("Configuration loaded successfully")
warn_if_auth_disabled_enabled()
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
logger.exception(error_msg)
@@ -182,6 +185,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Pre-warm tiktoken encoding cache so the first memory-injection request
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
# that may be unreachable in restricted networks — see issue #3402).
# When memory.token_counting is "char", token counting never touches
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
# network-restricted deployments — see issue #3429).
if startup_config.memory.token_counting == "char":
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
else:
try:
from deerflow.agents.memory.prompt import warm_tiktoken_cache
@@ -192,9 +201,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if warmed:
logger.info("tiktoken encoding cache warmed successfully")
else:
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
except TimeoutError:
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
except Exception:
logger.warning("tiktoken warm-up skipped", exc_info=True)
@@ -376,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
# User-facing IM channel connection API is mounted at /api/channels
app.include_router(channel_connections.router)
# Channels API is mounted at /api/channels
app.include_router(channels.router)
+56
View File
@@ -0,0 +1,56 @@
"""Shared helpers for local/E2E auth-disabled mode."""
from __future__ import annotations
import logging
import os
from types import SimpleNamespace
from deerflow.runtime.user_context import DEFAULT_USER_ID
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
AUTH_DISABLED_USER_EMAIL = "default@test.local"
AUTH_SOURCE_SESSION = "session"
AUTH_SOURCE_INTERNAL = "internal"
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
logger = logging.getLogger(__name__)
def is_explicit_production_environment() -> bool:
return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
def is_auth_disabled_requested() -> bool:
return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
def is_auth_disabled() -> bool:
return is_auth_disabled_requested() and not is_explicit_production_environment()
def warn_if_auth_disabled_enabled() -> None:
if not is_auth_disabled():
return
logger.warning(
"%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
AUTH_DISABLED_ENV_VAR,
AUTH_DISABLED_USER_ID,
)
def get_auth_disabled_user():
return SimpleNamespace(
id=AUTH_DISABLED_USER_ID,
email=AUTH_DISABLED_USER_EMAIL,
password_hash=None,
system_role="admin",
needs_setup=False,
token_version=0,
)
+31 -14
View File
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.auth_disabled import (
AUTH_SOURCE_AUTH_DISABLED,
AUTH_SOURCE_INTERNAL,
AUTH_SOURCE_SESSION,
get_auth_disabled_user,
is_auth_disabled,
)
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user
@@ -80,18 +87,14 @@ class AuthMiddleware(BaseHTTPMiddleware):
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user()
# Non-public path: require session cookie
if internal_user is None and not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
auth_source = AUTH_SOURCE_SESSION
access_token = request.cookies.get("access_token")
# Non-public path: require session cookie
if internal_user is not None:
user = internal_user
auth_source = AUTH_SOURCE_INTERNAL
elif access_token:
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
@@ -105,19 +108,33 @@ class AuthMiddleware(BaseHTTPMiddleware):
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
if internal_user is not None:
user = internal_user
else:
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
if not is_auth_disabled():
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
user = get_auth_disabled_user()
auth_source = AUTH_SOURCE_AUTH_DISABLED
elif is_auth_disabled():
user = get_auth_disabled_user()
auth_source = AUTH_SOURCE_AUTH_DISABLED
else:
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user
request.state.auth_source = auth_source
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user)
try:
+5
View File
@@ -276,6 +276,11 @@ def require_permission(
# strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404.
if owner_check:
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
if getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
return await func(*args, **kwargs)
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
+5
View File
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth_disabled import is_auth_disabled
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
if is_auth_disabled():
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
+11
View File
@@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request):
Raises HTTPException 401 if not authenticated.
"""
state = getattr(request, "state", None)
state_user = getattr(state, "user", None)
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
if state_user is not None and getattr(state, "auth_source", None) in {
AUTH_SOURCE_SESSION,
AUTH_SOURCE_AUTH_DISABLED,
AUTH_SOURCE_INTERNAL,
}:
return state_user
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
+25 -2
View File
@@ -5,10 +5,12 @@ from __future__ import annotations
import os
import secrets
from types import SimpleNamespace
from typing import Any
from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
INTERNAL_SYSTEM_ROLE = "internal"
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
def create_internal_auth_headers() -> dict[str, str]:
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
"""Return headers that authenticate trusted Gateway internal calls."""
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
if owner_user_id:
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
return headers
def is_valid_internal_auth_token(token: str | None) -> bool:
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
"""Return the owner override for a trusted internal request, if present.
The header is ignored for normal browser/API callers. It is only honored
after ``AuthMiddleware`` has validated the internal auth token and stamped
the synthetic internal user onto ``request.state.user``.
"""
user = getattr(getattr(request, "state", None), "user", None)
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
return None
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
if not owner_user_id:
return None
owner_user_id = owner_user_id.strip()
return owner_user_id or None
+7
View File
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
from app.gateway.deps import get_local_provider
auth = Auth()
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
if method.upper() not in _CSRF_METHODS:
return
if is_auth_disabled():
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
@@ -66,6 +70,9 @@ async def authenticate(request):
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
if is_auth_disabled():
return AUTH_DISABLED_USER_ID
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
+10
View File
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
user = await get_current_user_from_request(request)
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(
code=AuthErrorCode.INVALID_CREDENTIALS,
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
).model_dump(),
)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
@@ -0,0 +1,635 @@
"""Browser-facing APIs for user-owned IM channel bindings."""
from __future__ import annotations
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any
from fastapi import APIRouter, HTTPException, Request, Response
from pydantic import BaseModel, Field
from app.channels.runtime_config_store import (
ChannelRuntimeConfigStore,
apply_runtime_connection_config,
merge_runtime_channel_configs,
)
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
logger = logging.getLogger(__name__)
_STATE_TTL_SECONDS = 600
_MASKED_CREDENTIAL_VALUE = "********"
class ChannelCredentialFieldResponse(BaseModel):
name: str
label: str
type: str = "text"
required: bool = True
class ChannelProviderResponse(BaseModel):
provider: str
display_name: str
enabled: bool
configured: bool
connectable: bool
unavailable_reason: str | None = None
auth_mode: str
connection_status: str
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
credential_values: dict[str, str] = Field(default_factory=dict)
class ChannelProvidersResponse(BaseModel):
enabled: bool
providers: list[ChannelProviderResponse]
class ChannelConnectionResponse(BaseModel):
id: str
provider: str
status: str
external_account_id: str | None = None
external_account_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
scopes: list[str] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
class ChannelConnectionsResponse(BaseModel):
connections: list[ChannelConnectionResponse]
class ChannelConnectResponse(BaseModel):
provider: str
mode: str
url: str | None = None
code: str
instruction: str
expires_in: int
class ChannelRuntimeConfigRequest(BaseModel):
values: dict[str, str] = Field(default_factory=dict)
_PROVIDER_META: dict[str, dict[str, str]] = {
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
}
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
"telegram": (
{"name": "bot_token", "label": "Bot token", "type": "password"},
{"name": "bot_username", "label": "Bot username", "type": "text"},
),
"slack": (
{"name": "bot_token", "label": "Bot token", "type": "password"},
{"name": "app_token", "label": "App token", "type": "password"},
),
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
"feishu": (
{"name": "app_id", "label": "App ID", "type": "text"},
{"name": "app_secret", "label": "App secret", "type": "password"},
),
"dingtalk": (
{"name": "client_id", "label": "Client ID", "type": "text"},
{"name": "client_secret", "label": "Client secret", "type": "password"},
),
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
"wecom": (
{"name": "bot_id", "label": "Bot ID", "type": "text"},
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
),
}
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
"telegram": ("bot_token",),
"slack": ("bot_token", "app_token"),
"discord": ("bot_token",),
"feishu": ("app_id", "app_secret"),
"dingtalk": ("client_id", "client_secret"),
"wechat": ("bot_token",),
"wecom": ("bot_id", "bot_secret"),
}
def _get_user_id(request: Request) -> str:
user = getattr(request.state, "user", None)
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
return str(user.id)
def _get_app_config():
from deerflow.config.app_config import get_app_config
return get_app_config()
def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
store = getattr(request.app.state, "channel_runtime_config_store", None)
if isinstance(store, ChannelRuntimeConfigStore):
return store
store = ChannelRuntimeConfigStore()
request.app.state.channel_runtime_config_store = store
return store
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
config = getattr(request.app.state, "channel_connections_config", None)
if not isinstance(config, ChannelConnectionsConfig):
config = _get_app_config().channel_connections
config = apply_runtime_connection_config(config, store=_get_runtime_config_store(request))
request.app.state.channel_connections_config = config
return config
def _get_channels_config(request: Request) -> dict[str, Any]:
state_config = getattr(request.app.state, "channels_config", None)
if isinstance(state_config, dict):
return state_config
result = _load_channels_config(request, _get_channel_connections_config(request))
request.app.state.channels_config = result
return result
def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
app_config = _get_app_config()
extra = app_config.model_extra or {}
channels_config = extra.get("channels")
result = dict(channels_config) if isinstance(channels_config, dict) else {}
merge_runtime_channel_configs(
result,
config,
store=_get_runtime_config_store(request),
)
return result
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
repo = getattr(request.app.state, "channel_connection_repo", None)
if isinstance(repo, ChannelConnectionRepository):
return repo
sf = get_session_factory()
if sf is None:
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
repo = ChannelConnectionRepository(sf)
request.app.state.channel_connection_repo = repo
return repo
def _provider_config(config: ChannelConnectionsConfig, provider: str):
provider_config = getattr(config, provider, None)
if provider_config is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return provider_config
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
return False
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
def _runtime_unavailable_reason(provider: str) -> str:
meta = _PROVIDER_META.get(provider)
display_name = meta["display_name"] if meta else provider
return f"Enter the required {display_name} credentials to connect this channel."
def _runtime_not_running_reason(provider: str) -> str:
meta = _PROVIDER_META.get(provider)
display_name = meta["display_name"] if meta else provider
return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
def _runtime_channel_running(provider: str) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.debug("Unable to inspect channel service status", exc_info=True)
return None
service = get_channel_service()
if service is None:
return None
try:
status = service.get_status()
except Exception:
logger.debug("Unable to read channel service status", exc_info=True)
return None
if not status.get("service_running"):
return False
channel_status = status.get("channels", {}).get(provider)
if not isinstance(channel_status, dict):
return None
return bool(channel_status.get("running"))
async def _ensure_runtime_channel_ready_if_available(
provider: str,
channels_config: dict[str, Any],
) -> bool | None:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
return None
try:
from app.channels.service import get_channel_service
except Exception:
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
return None
service = get_channel_service()
if service is None:
return None
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
if ensure_channel_ready is None:
return None
try:
return await ensure_channel_ready(provider, runtime_config)
except Exception:
logger.exception("Failed to reconcile runtime channel readiness")
return False
def _provider_unavailable_reason(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
) -> str | None:
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
return None
if not provider_config.configured:
return _runtime_unavailable_reason(provider)
if not _runtime_channel_configured(provider, channels_config):
return _runtime_unavailable_reason(provider)
if _runtime_channel_running(provider) is False:
return _runtime_not_running_reason(provider)
return None
def _provider_status(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
) -> tuple[dict[str, bool], str | None]:
declared = config.provider_status(provider)
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
def _new_binding_code() -> str:
return secrets.token_urlsafe(16)
async def _create_state(
repo: ChannelConnectionRepository,
*,
owner_user_id: str,
provider: str,
) -> str:
state = _new_binding_code()
await repo.create_oauth_state(
owner_user_id=owner_user_id,
provider=provider,
state=state,
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
)
return state
def _connect_instruction(provider: str, code: str) -> str:
if provider == "telegram":
return f"Send /start {code} to the DeerFlow Telegram bot."
meta = _PROVIDER_META.get(provider)
if meta is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
if provider == "telegram":
provider_config = _provider_config(config, provider)
return f"https://t.me/{provider_config.bot_username}?start={code}"
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
return None
raise HTTPException(status_code=404, detail="Unknown channel provider")
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
value = connection.get("updated_at")
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
if isinstance(value, str) and value:
try:
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
pass
return datetime.min.replace(tzinfo=UTC)
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
by_provider: dict[str, dict[str, Any]] = {}
for item in connections:
existing = by_provider.get(item["provider"])
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
by_provider[item["provider"]] = item
return by_provider
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
fields = _CREDENTIAL_FIELDS.get(provider)
if fields is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return [ChannelCredentialFieldResponse(**field) for field in fields]
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict):
return {}
values: dict[str, str] = {}
for field in _credential_fields(provider):
value = str(runtime_config.get(field.name) or "").strip()
if not value:
continue
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
return values
def _provider_response(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
meta: dict[str, str],
connection: dict[str, Any] | None = None,
) -> ChannelProviderResponse:
status, unavailable_reason = _provider_status(config, channels_config, provider)
if connection:
connection_status = connection["status"]
elif status["configured"] and unavailable_reason is None:
connection_status = "connected"
else:
connection_status = "not_connected"
credential_values = _credential_values(provider, channels_config)
if provider == "telegram" and not credential_values.get("bot_username"):
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
if bot_username:
credential_values["bot_username"] = bot_username
return ChannelProviderResponse(
provider=provider,
display_name=meta["display_name"],
enabled=status["enabled"],
configured=status["configured"],
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
unavailable_reason=unavailable_reason,
auth_mode=meta["auth_mode"],
connection_status=connection_status,
credential_fields=_credential_fields(provider),
credential_values=credential_values,
)
def _required_runtime_values(
provider: str,
values: dict[str, str],
existing_config: dict[str, Any] | None = None,
) -> dict[str, str]:
fields = _credential_fields(provider)
cleaned: dict[str, str] = {}
missing: list[str] = []
existing_config = existing_config or {}
for field in fields:
raw_value = values.get(field.name, "")
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
existing_value = str(existing_config.get(field.name) or "").strip()
if existing_value:
cleaned[field.name] = existing_value
continue
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
if field.required and not value:
missing.append(field.label)
cleaned[field.name] = value
if missing:
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
return cleaned
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while configuring a runtime channel")
return None
service = get_channel_service()
if service is None:
return None
return await service.configure_channel(provider, runtime_config)
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while disconnecting a runtime channel")
return None
service = get_channel_service()
if service is None:
return None
runtime_config = channels_config.get(provider)
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
return await service.configure_channel(provider, runtime_config)
return await service.remove_channel(provider)
@router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
config = _get_channel_connections_config(request)
channels_config = _get_channels_config(request)
repo = None
if config.enabled:
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
owner_user_id = _get_user_id(request)
connections = await repo.list_connections(owner_user_id) if repo is not None else []
by_provider = _newest_connection_by_provider(connections)
providers: list[ChannelProviderResponse] = []
for provider, meta in _PROVIDER_META.items():
if not config.provider_status(provider)["enabled"]:
continue
if _runtime_channel_configured(provider, channels_config):
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
connection = by_provider.get(provider)
providers.append(_provider_response(config, channels_config, provider, meta, connection))
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
@router.get("/connections", response_model=ChannelConnectionsResponse)
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
config = _get_channel_connections_config(request)
if not config.enabled:
return ChannelConnectionsResponse(connections=[])
repo = _get_repository(request, config)
rows = await repo.list_connections(_get_user_id(request))
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
@router.delete("/connections/{connection_id}", status_code=204)
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
config = _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
repo = _get_repository(request, config)
disconnected = await repo.disconnect_connection(
connection_id=connection_id,
owner_user_id=_get_user_id(request),
)
if not disconnected:
raise HTTPException(status_code=404, detail="Channel connection not found")
return Response(status_code=204)
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
config = _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
owner_user_id = _get_user_id(request)
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
repo = None
if repo is not None:
for connection in await repo.list_connections(owner_user_id):
if connection["provider"] == provider and connection["status"] != "revoked":
await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id=owner_user_id,
)
_get_runtime_config_store(request).set_provider_disconnected(provider)
channels_config = _load_channels_config(request, config)
request.app.state.channels_config = channels_config
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
if stopped is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
config = _get_channel_connections_config(request)
channels_config = _get_channels_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
status, unavailable_reason = _provider_status(config, channels_config, provider)
if not status["enabled"]:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
if unavailable_reason:
raise HTTPException(status_code=400, detail=unavailable_reason)
if not status["configured"]:
raise HTTPException(status_code=400, detail="Channel provider is not configured")
repo = _get_repository(request, config)
code = await _create_state(
repo,
owner_user_id=_get_user_id(request),
provider=provider,
)
return ChannelConnectResponse(
provider=provider,
mode=_PROVIDER_META[provider]["auth_mode"],
url=_connect_url(config, provider, code),
code=code,
instruction=_connect_instruction(provider, code),
expires_in=_STATE_TTL_SECONDS,
)
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def configure_channel_provider_runtime(
provider: str,
body: ChannelRuntimeConfigRequest,
request: Request,
) -> ChannelProviderResponse:
config = _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
channels_config = _get_channels_config(request)
existing = channels_config.get(provider)
runtime_config = dict(existing) if isinstance(existing, dict) else {}
values = _required_runtime_values(provider, body.values, runtime_config)
runtime_config["enabled"] = True
for key in _RUNTIME_REQUIREMENTS[provider]:
runtime_config[key] = values[key]
if provider == "telegram":
runtime_config["bot_username"] = values["bot_username"]
provider_config.bot_username = values["bot_username"]
request.app.state.channel_connections_config = config
channels_config[provider] = runtime_config
request.app.state.channels_config = channels_config
started = await _restart_runtime_channel_if_available(provider, runtime_config)
if started is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
_get_runtime_config_store(request).set_provider_config(provider, runtime_config)
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
+5 -1
View File
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
class MemoryStatusResponse(BaseModel):
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
"max_facts": 100,
"fact_confidence_threshold": 0.7,
"injection_enabled": true,
"max_injection_tokens": 2000
"max_injection_tokens": 2000,
"token_counting": "tiktoken"
}
```
"""
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
token_counting=config.token_counting,
)
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
token_counting=config.token_counting,
),
data=MemoryResponse(**memory_data),
)
+11 -1
View File
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = now_iso()
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id)
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is None and thread_owner_user_id:
unscoped_record = await thread_store.get(thread_id, user_id=None)
if unscoped_record is not None:
if unscoped_record.get("user_id") != thread_owner_user_id:
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
await thread_store.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
**thread_owner_kwargs,
metadata=body.metadata,
)
except Exception:
+31 -1
View File
@@ -12,6 +12,7 @@ import json
import logging
import re
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any
from fastapi import HTTPException, Request
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config
from deerflow.runtime import (
@@ -35,6 +36,7 @@ from deerflow.runtime import (
run_agent,
)
from deerflow.runtime.runs.naming import resolve_root_run_name
from deerflow.runtime.user_context import reset_current_user, set_current_user
logger = logging.getLogger(__name__)
@@ -315,6 +317,24 @@ async def start_run(
detail=f"Model {model_name!r} is not in the configured model allowlist",
)
owner_user_id = get_trusted_internal_owner_user_id(request)
# Stateless run endpoints carry thread_id in the request *body*, so the
# @require_permission(owner_check=True) decorator -- which resolves ownership
# from the path param -- cannot protect them. Enforce thread ownership here,
# before any run is created, so one user cannot start runs on (or read /wait
# checkpoint state from) another user's thread. Missing rows (auto-created
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
# via check_access; only a thread already owned by another user is rejected
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
# channel runs act on behalf of IM users they do not own (see
# inject_authenticated_user_context), so the internal system role is exempt.
user = getattr(request.state, "user", None)
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
try:
try:
record = await run_mgr.create_or_reject(
thread_id,
@@ -324,6 +344,7 @@ async def start_run(
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
user_id=owner_user_id,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
@@ -335,6 +356,12 @@ async def start_run(
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None and owner_user_id:
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
if unscoped_existing is not None:
if unscoped_existing.get("user_id") != owner_user_id:
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
@@ -381,6 +408,9 @@ async def start_run(
# after the run completes.
return record
finally:
if owner_context_token is not None:
reset_current_user(owner_context_token)
async def sse_consumer(
+121
View File
@@ -0,0 +1,121 @@
# IM Channel Connections
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
## Configuration
Configure the actual IM bots under the existing `channels` block:
```yaml
channels:
telegram:
enabled: true
bot_token: $TELEGRAM_BOT_TOKEN
slack:
enabled: true
bot_token: $SLACK_BOT_TOKEN
app_token: $SLACK_APP_TOKEN
discord:
enabled: true
bot_token: $DISCORD_BOT_TOKEN
feishu:
enabled: true
app_id: $FEISHU_APP_ID
app_secret: $FEISHU_APP_SECRET
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID
client_secret: $DINGTALK_CLIENT_SECRET
wechat:
enabled: true
bot_token: $WECHAT_BOT_TOKEN
wecom:
enabled: true
bot_id: $WECOM_BOT_ID
bot_secret: $WECOM_BOT_SECRET
```
Then enable user bindings in `channel_connections`:
```yaml
channel_connections:
enabled: true
telegram:
enabled: true
bot_username: $TELEGRAM_BOT_USERNAME
slack:
enabled: true
discord:
enabled: true
feishu:
enabled: true
dingtalk:
enabled: true
wechat:
enabled: true
wecom:
enabled: true
```
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
## Connect Flow
Telegram:
- The frontend creates a short one-time code.
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
Slack:
- The frontend creates a short one-time code.
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
Discord:
- The frontend creates a short one-time code.
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
Feishu/Lark, DingTalk, WeChat, and WeCom:
- The frontend creates a short one-time code.
- The UI shows `Send /connect <code> to the DeerFlow <Provider> bot.`
- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user.
Codes expire after 10 minutes and are single-use.
## Runtime Model
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
## Security Notes
- Browser APIs remain authenticated and CSRF-protected.
- Connect codes are random, short-lived, and single-use.
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
- This implementation does not add public provider callback or webhook routes.
+2 -1
View File
@@ -31,7 +31,8 @@ Current injection format:
Token counting:
- Uses `tiktoken` (`cl100k_base`) when available
- Falls back to `len(text) // 4` if tokenizer import fails
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
## Known Gap
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
return ""
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
memory_content = format_memory_for_injection(
memory_data,
max_tokens=config.max_injection_tokens,
use_tiktoken=(config.token_counting == "tiktoken"),
)
if not memory_content.strip():
return ""
@@ -5,7 +5,9 @@ from __future__ import annotations
import logging
import math
import re
from typing import Any
import threading
import time
from typing import Any, cast
logger = logging.getLogger(__name__)
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
# (potentially slow) first ``get_encoding`` call.
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
#
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
# a network-restricted environment does not re-attempt the blocking BPE
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
# failure is allowed to expire so a transient network outage can self-heal back
# to accurate tiktoken counting without a process restart. A load already in
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
# fall back immediately instead of spawning more blocking
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
# config to skip tiktoken entirely.
_TIKTOKEN_ENCODING_MISSING = object()
_TIKTOKEN_ENCODING_LOADING = object()
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
# tuning constant rather than a user-facing config: it only affects how quickly
# the default ``tiktoken`` mode self-heals after a transient network outage.
# Deployments that want to avoid tiktoken's network dependency entirely should
# set ``memory.token_counting: char`` instead of tuning this value.
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
_tiktoken_encoding_cache: dict[str, Any] = {}
_tiktoken_encoding_cache_lock = threading.Lock()
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
download can block for tens of minutes before the OS TCP timeout kicks in.
The caller must therefore be prepared for this to block and should run it
off the event loop (e.g. via ``asyncio.to_thread``).
A failed load is remembered (with a timestamp) so subsequent calls fall
back immediately to character-based estimation instead of re-triggering the
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
so a transient outage can self-heal without a restart. A load already in
progress is also remembered so that a timed-out caller does not leave a
window where later requests start more blocking ``get_encoding`` calls.
"""
if not TIKTOKEN_AVAILABLE:
return None
cached = _tiktoken_encoding_cache.get(encoding_name)
if cached is not None:
return cached
with _tiktoken_encoding_cache_lock:
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
if cached is _TIKTOKEN_ENCODING_LOADING:
return None
if isinstance(cached, tuple):
# Cached failure: (None, failed_at). Retry only after cooldown.
_, failed_at = cached
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
return None
cached = _TIKTOKEN_ENCODING_MISSING
if cached is not _TIKTOKEN_ENCODING_MISSING:
return cast("tiktoken.Encoding", cached)
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
try:
encoding = tiktoken.get_encoding(encoding_name)
_tiktoken_encoding_cache[encoding_name] = encoding
return encoding
except Exception:
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
with _tiktoken_encoding_cache_lock:
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
return None
with _tiktoken_encoding_cache_lock:
_tiktoken_encoding_cache[encoding_name] = encoding
return encoding
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
def _char_based_token_estimate(text: str) -> int:
"""Network-free token estimate that accounts for CJK density.
The plain ``len(text) // 4`` heuristic is reasonable for English/code
(~4 chars per token) but significantly under-estimates token counts for
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
characters per token. Counting CJK characters separately (~2 chars per
token) avoids over-filling the injection budget for CJK-heavy memory
content.
"""
cjk = sum(
1
for ch in text
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
)
return (len(text) - cjk) // 4 + cjk // 2
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
"""Count tokens in text using tiktoken.
Args:
text: The text to count tokens for.
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
use_tiktoken: When ``False``, skip tiktoken entirely and use the
network-free character-based estimate. This guarantees no BPE
download is attempted (see ``memory.token_counting`` config).
Returns:
The number of tokens in the text.
"""
if not use_tiktoken:
return _char_based_token_estimate(text)
encoding = _get_tiktoken_encoding(encoding_name)
if encoding is None:
# Fallback to character-based estimation if tiktoken is not available
# or the encoding failed to load.
return len(text) // 4
# Fallback to CJK-aware character estimation if tiktoken is not
# available or the encoding failed to load.
return _char_based_token_estimate(text)
try:
return len(encoding.encode(text))
except Exception:
# Fallback to character-based estimation on error
return len(text) // 4
# Fallback to CJK-aware character estimation on error.
return _char_based_token_estimate(text)
def warm_tiktoken_cache() -> bool:
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
return max(0.0, min(1.0, confidence))
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
"""Format memory data for injection into system prompt.
Args:
memory_data: The memory data dictionary.
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
use_tiktoken: When ``False``, all token counting uses the network-free
character-based estimate instead of tiktoken (see
``memory.token_counting`` config). Defaults to ``True``.
Returns:
Formatted memory string for system prompt injection.
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
# Compute token count for existing sections once, then account
# incrementally for each fact line to avoid full-string re-tokenization.
base_text = "\n\n".join(sections)
base_tokens = _count_tokens(base_text) if base_text else 0
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
# Account for the separator between existing sections and the facts section.
facts_header = "Facts:\n"
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
running_tokens = base_tokens + separator_tokens
fact_lines: list[str] = []
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
# Each additional line is preceded by a newline (except the first).
line_text = ("\n" + line) if fact_lines else line
line_tokens = _count_tokens(line_text)
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
if running_tokens + line_tokens <= max_tokens:
fact_lines.append(line)
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
result = "\n\n".join(sections)
# Use accurate token counting with tiktoken
token_count = _count_tokens(result)
# Use accurate token counting with tiktoken (or the char-based estimate
# when use_tiktoken is False).
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
if token_count > max_tokens:
# Truncate to fit within token limit
# Estimate characters to remove based on token ratio
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
"fact_confidence_threshold": config.fact_confidence_threshold,
"injection_enabled": config.injection_enabled,
"max_injection_tokens": config.max_injection_tokens,
"token_counting": config.token_counting,
}
def get_memory_status(self) -> dict:
@@ -470,14 +470,32 @@ class AioSandboxProvider(SandboxProvider):
existing_id = self._thread_sandboxes[thread_id]
if existing_id in self._sandboxes:
info = self._sandbox_infos.get(existing_id)
else:
del self._thread_sandboxes[thread_id]
return None
alive = self._check_tracked_sandbox_alive(existing_id, info) if info is not None else True
if alive is False:
self._drop_unhealthy_sandbox(
existing_id,
"in-process cache failed health check",
expected_info=info,
)
return None
with self._lock:
if self._thread_sandboxes.get(thread_id) != existing_id:
return None
if existing_id not in self._sandboxes:
self._thread_sandboxes.pop(thread_id, None)
return None
suffix = " (post-lock check)" if post_lock else ""
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
self._last_activity[existing_id] = time.time()
return existing_id
del self._thread_sandboxes[thread_id]
return None
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
"""Promote a warm-pool sandbox back to active tracking if available."""
if thread_id is None:
@@ -487,7 +505,22 @@ class AioSandboxProvider(SandboxProvider):
if sandbox_id not in self._warm_pool:
return None
info, _ = self._warm_pool.pop(sandbox_id)
info, _ = self._warm_pool[sandbox_id]
alive = self._check_tracked_sandbox_alive(sandbox_id, info)
if alive is False:
self._drop_unhealthy_sandbox(
sandbox_id,
"warm-pool cache failed health check",
expected_info=info,
)
return None
with self._lock:
warm_item = self._warm_pool.pop(sandbox_id, None)
if warm_item is None:
return None
info, _ = warm_item
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
@@ -527,6 +560,70 @@ class AioSandboxProvider(SandboxProvider):
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
return sandbox_id
def _check_tracked_sandbox_alive(self, sandbox_id: str, info: SandboxInfo) -> bool | None:
"""Return whether a tracked sandbox appears alive, or None if unknown."""
try:
return self._backend.is_alive(info)
except Exception as e:
logger.warning(f"Failed to check sandbox {sandbox_id} health: {e}")
return None
def _remove_tracked_sandbox(
self,
sandbox_id: str,
*,
expected_info: SandboxInfo | None = None,
) -> tuple[Sandbox | None, SandboxInfo | None, bool]:
"""Remove a sandbox from in-process tracking maps.
When expected_info is provided, removal only happens if the currently
tracked active or warm-pool entry is the exact info object that was
checked. This prevents a stale health-check result from deleting a
freshly recreated sandbox with the same deterministic id.
"""
thread_ids_to_remove: list[str] = []
with self._lock:
active_info = self._sandbox_infos.get(sandbox_id)
warm_item = self._warm_pool.get(sandbox_id)
warm_info = warm_item[0] if warm_item is not None else None
if expected_info is not None and active_info is not expected_info and warm_info is not expected_info:
return None, None, False
sandbox = self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
self._last_activity.pop(sandbox_id, None)
if info is None and sandbox_id in self._warm_pool:
info, _ = self._warm_pool.pop(sandbox_id)
else:
self._warm_pool.pop(sandbox_id, None)
return sandbox, info, True
def _drop_unhealthy_sandbox(self, sandbox_id: str, reason: str, *, expected_info: SandboxInfo | None = None) -> None:
"""Remove and destroy a sandbox after a definitive failed health check."""
sandbox, info, removed = self._remove_tracked_sandbox(sandbox_id, expected_info=expected_info)
if not removed:
logger.info(f"Skipped dropping sandbox {sandbox_id}: tracked info changed after health check")
return
if sandbox is not None:
try:
sandbox.close()
except Exception as e:
logger.warning(f"Error closing unhealthy sandbox {sandbox_id}: {e}")
if info is not None:
try:
self._backend.destroy(info)
except Exception as e:
logger.warning(f"Error destroying unhealthy sandbox {sandbox_id}: {e}")
logger.warning(f"Dropped unhealthy sandbox {sandbox_id}: {reason}")
def _replica_count(self) -> tuple[int, int]:
"""Return configured replicas and currently tracked sandbox count."""
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
@@ -617,7 +714,7 @@ class AioSandboxProvider(SandboxProvider):
async def _acquire_internal_async(self, thread_id: str | None) -> str:
"""Async counterpart to ``_acquire_internal``."""
cached_id = self._reuse_in_process_sandbox(thread_id)
cached_id = await asyncio.to_thread(self._reuse_in_process_sandbox, thread_id)
if cached_id is not None:
return cached_id
@@ -625,7 +722,7 @@ class AioSandboxProvider(SandboxProvider):
sandbox_id = self._sandbox_id_for_thread(thread_id)
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
reclaimed_id = await asyncio.to_thread(self._reclaim_warm_pool_sandbox, thread_id, sandbox_id)
if reclaimed_id is not None:
return reclaimed_id
@@ -681,7 +778,7 @@ class AioSandboxProvider(SandboxProvider):
locked = True
# Re-check in-process caches under the file lock in case another
# thread in this process won the race while we were waiting.
cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
cached_id = await asyncio.to_thread(self._recheck_cached_sandbox, thread_id, sandbox_id)
if cached_id is not None:
return cached_id
@@ -837,22 +934,7 @@ class AioSandboxProvider(SandboxProvider):
Args:
sandbox_id: The ID of the sandbox to destroy.
"""
info = None
sandbox = None
thread_ids_to_remove: list[str] = []
with self._lock:
sandbox = self._sandboxes.pop(sandbox_id, None)
info = self._sandbox_infos.pop(sandbox_id, None)
thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id]
for tid in thread_ids_to_remove:
del self._thread_sandboxes[tid]
self._last_activity.pop(sandbox_id, None)
# Also pull from warm pool if it was parked there
if info is None and sandbox_id in self._warm_pool:
info, _ = self._warm_pool.pop(sandbox_id)
else:
self._warm_pool.pop(sandbox_id, None)
sandbox, info, _ = self._remove_tracked_sandbox(sandbox_id)
if sandbox is not None:
# Defense-in-depth: close() already swallows its own errors; this
@@ -169,6 +169,24 @@ def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str |
return "0.0.0.0"
def _is_no_such_container_error(stderr: str, container_name: str) -> bool:
"""Return True only when stderr definitively says the container does not exist.
Docker reports "No such object" / "No such container". Apple Container
reports a generic "not found", so that phrase is only trusted when the
message also names the inspected container (or refers to a
container/object); transient failures whose text happens to contain
"not found" (e.g. "command not found", "context not found") must stay on
the raise path instead of being misread as a dead container.
"""
message = stderr.lower()
if "no such object" in message or "no such container" in message:
return True
if "not found" not in message:
return False
return container_name.lower() in message or "container" in message or "object" in message
class LocalContainerBackend(SandboxBackend):
"""Backend that manages sandbox containers locally using Docker or Apple Container.
@@ -335,11 +353,21 @@ class LocalContainerBackend(SandboxBackend):
sandbox_id: The deterministic sandbox ID (determines container name).
Returns:
SandboxInfo if container found and healthy, None otherwise.
SandboxInfo if container found and healthy, None otherwise. A
failed runtime check (e.g. transient daemon error) also returns
None discovery must not adopt a container it cannot verify, and
falling through to create keeps acquire recoverable instead of
hard-failing on a hiccup.
"""
container_name = f"{self._container_prefix}-{sandbox_id}"
if not self._is_container_running(container_name):
try:
running = self._is_container_running(container_name)
except RuntimeError as e:
logger.warning(f"Could not verify container {container_name} during discovery; not adopting it: {e}")
return None
if not running:
return None
port = self._get_container_port(container_name)
@@ -582,6 +610,13 @@ class LocalContainerBackend(SandboxBackend):
This enables cross-process container discovery any process can detect
containers started by another process via the deterministic container name.
Raises:
RuntimeError: If the container runtime cannot answer the inspect
query. A failed check is intentionally distinct from a
definitive "container does not exist" result so callers do not
destroy healthy containers during transient Docker/Container
daemon failures.
"""
try:
result = subprocess.run(
@@ -590,9 +625,14 @@ class LocalContainerBackend(SandboxBackend):
text=True,
timeout=5,
)
return result.returncode == 0 and result.stdout.strip().lower() == "true"
except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
except subprocess.TimeoutExpired as exc:
raise RuntimeError(f"Timed out checking container {container_name}") from exc
if result.returncode == 0:
return result.stdout.strip().lower() == "true"
if _is_no_such_container_error(result.stderr, container_name):
return False
raise RuntimeError(f"Failed to inspect container {container_name}: {result.stderr.strip()}")
def _get_container_port(self, container_name: str) -> int | None:
"""Get the host port of a running container.
@@ -176,12 +176,16 @@ class RemoteSandboxBackend(SandboxBackend):
f"{self._provisioner_url}/api/sandboxes/{sandbox_id}",
timeout=10,
)
if resp.ok:
except requests.RequestException as exc:
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: {exc}") from exc
if resp.status_code == 404:
return False
if not resp.ok:
raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: HTTP {resp.status_code} {resp.text}")
data = resp.json()
return data.get("status") == "Running"
return False
except requests.RequestException:
return False
def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None:
"""GET /api/sandboxes/{sandbox_id} → discover existing sandbox."""
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
paths = get_paths()
effective_user = user_id or get_effective_user_id()
user_path = paths.user_agent_dir(effective_user, name)
if user_path.exists():
# Require config.yaml to confirm this is a genuine agent directory,
# not a leftover from memory/storage writes (see #3390).
if user_path.exists() and (user_path / "config.yaml").exists():
return user_path
legacy_path = paths.agent_dir(name)
if legacy_path.exists():
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
return legacy_path
return user_path
@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
@@ -116,6 +117,7 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
channel_connections: ChannelConnectionsConfig = Field(default_factory=ChannelConnectionsConfig, description="User-facing IM channel connection configuration")
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow")
@@ -0,0 +1,61 @@
"""Configuration for user-owned IM channel connections."""
from __future__ import annotations
from pydantic import BaseModel, Field
class SlackChannelConnectionConfig(BaseModel):
enabled: bool = False
@property
def configured(self) -> bool:
return True
class TelegramChannelConnectionConfig(BaseModel):
enabled: bool = False
bot_username: str = ""
@property
def configured(self) -> bool:
return bool(self.bot_username)
class DiscordChannelConnectionConfig(BaseModel):
enabled: bool = False
@property
def configured(self) -> bool:
return True
class BindingCodeChannelConnectionConfig(BaseModel):
enabled: bool = False
@property
def configured(self) -> bool:
return True
class ChannelConnectionsConfig(BaseModel):
"""Top-level config for browser-connectable IM channels."""
enabled: bool = False
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
feishu: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
dingtalk: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
wechat: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
wecom: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
def provider_status(self, provider: str) -> dict[str, bool]:
config = getattr(self, provider, None)
if config is None:
return {"enabled": False, "configured": False}
enabled = bool(config.enabled)
return {
"enabled": enabled,
"configured": enabled and bool(config.configured),
}
@@ -1,5 +1,7 @@
"""Configuration for memory mechanism."""
from typing import Literal
from pydantic import BaseModel, Field
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
le=8000,
description="Maximum tokens to use for memory injection",
)
token_counting: Literal["tiktoken", "char"] = Field(
default="tiktoken",
description=(
"Token counting strategy for memory-injection budgeting. "
"'tiktoken' is accurate but the encoding's BPE data may be "
"downloaded from a public network endpoint on first use, which "
"can block for a long time in network-restricted environments "
"(see issue #3402/#3429). 'char' uses a network-free "
"CJK-aware character-based estimate and never touches tiktoken."
),
)
# Global configuration instance
@@ -1,4 +1,5 @@
import hashlib
import logging
import os
import re
import shutil
@@ -13,6 +14,9 @@ _SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
_SAFE_USER_ID_DIGEST_RE = re.compile(r"^[0-9a-f]{16}$")
logger = logging.getLogger(__name__)
def _default_local_base_dir() -> Path:
@@ -47,10 +51,17 @@ def make_safe_user_id(raw: str) -> str:
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
if sanitized == raw:
return raw
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
return f"{sanitized}-{digest}"
def _looks_like_digested_user_dir(name: str, prefix: str) -> bool:
if not name.startswith(prefix):
return False
suffix = name[len(prefix) :]
return bool(_SAFE_USER_ID_DIGEST_RE.match(suffix))
def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style.
@@ -172,6 +183,45 @@ class Paths:
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
return self.base_dir / "users" / _validate_user_id(user_id)
def prepare_user_dir_for_raw_id(self, raw_user_id: str) -> str:
"""Return the safe user ID and migrate a unique legacy unsafe-id bucket.
A previous branch revision used a weak digest for unsafe external user IDs.
New IDs use SHA-256, but if exactly one old-style bucket with the same
sanitized prefix already exists, move it to the current bucket name so
existing local memory/files/threads remain visible.
"""
safe_user_id = make_safe_user_id(raw_user_id)
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw_user_id)
if safe_user_id == raw_user_id:
return safe_user_id
users_dir = self.base_dir / "users"
target_dir = users_dir / safe_user_id
if target_dir.exists() or not users_dir.exists():
return safe_user_id
legacy_prefix = f"{sanitized}-"
try:
legacy_candidates = [candidate for candidate in users_dir.iterdir() if candidate.is_dir() and candidate.name != safe_user_id and _looks_like_digested_user_dir(candidate.name, legacy_prefix)]
except OSError:
logger.exception("Failed to inspect user directories for legacy unsafe-id migration")
return safe_user_id
if not legacy_candidates:
return safe_user_id
if len(legacy_candidates) > 1:
logger.warning("Multiple legacy unsafe-id user directories matched; skipping automatic migration")
return safe_user_id
try:
legacy_candidates[0].rename(target_dir)
logger.info("Migrated legacy unsafe-id user directory to the current digest format")
except OSError:
logger.exception("Failed to migrate legacy unsafe-id user directory")
return safe_user_id
def user_memory_file(self, user_id: str) -> Path:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
@@ -4,7 +4,20 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount."""
host_path: str = Field(..., description="Path on the host machine")
host_path: str = Field(
...,
description=(
"Source path for the mount. Resolution depends on the active provider: "
"``LocalSandboxProvider`` checks this path from the gateway process — in "
"``make dev`` that is the host machine, but in Docker deployments "
"(``make up`` / docker-compose) it is the path *inside* the "
"``deer-flow-gateway`` container, so the host directory must also be "
"bind-mounted into the gateway service for the mount to take effect. "
"``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` "
"for the sandbox container, where it is resolved by the host Docker daemon "
"from the host machine's perspective."
),
)
container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -0,0 +1,21 @@
"""User-owned IM channel connection persistence."""
from deerflow.persistence.channel_connections.model import (
ChannelConnectionRow,
ChannelConversationRow,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
from deerflow.persistence.channel_connections.sql import (
ChannelConnectionRepository,
ChannelCredentialCipher,
)
__all__ = [
"ChannelConnectionRepository",
"ChannelConnectionRow",
"ChannelConversationRow",
"ChannelCredentialCipher",
"ChannelCredentialRow",
"ChannelOAuthStateRow",
]
@@ -0,0 +1,111 @@
"""ORM models for user-owned IM channel connections."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
def _utc_now() -> datetime:
return datetime.now(UTC)
class ChannelConnectionRow(Base):
__tablename__ = "channel_connections"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected")
external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
scopes_json: Mapped[list] = mapped_column(JSON, default=list)
capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict)
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
__table_args__ = (
UniqueConstraint(
"owner_user_id",
"provider",
"external_account_id",
"workspace_id",
name="uq_channel_connection_owner_provider_identity",
),
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
)
class ChannelCredentialRow(Base):
__tablename__ = "channel_credentials"
connection_id: Mapped[str] = mapped_column(
String(64),
ForeignKey("channel_connections.id", ondelete="CASCADE"),
primary_key=True,
)
encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True)
encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True)
token_type: Mapped[str | None] = mapped_column(String(32), nullable=True)
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True)
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
class ChannelOAuthStateRow(Base):
__tablename__ = "channel_oauth_states"
state_hash: Mapped[str] = mapped_column(String(128), primary_key=True)
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True)
requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list)
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
class ChannelConversationRow(Base):
__tablename__ = "channel_conversations"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
connection_id: Mapped[str] = mapped_column(
String(64),
ForeignKey("channel_connections.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False)
external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
__table_args__ = (
UniqueConstraint(
"connection_id",
"external_conversation_id",
"external_topic_id",
name="uq_channel_conversation_connection_external",
),
)
@@ -0,0 +1,349 @@
"""SQL repository for user-owned IM channel connections."""
from __future__ import annotations
import base64
import hashlib
import json
import uuid
from datetime import UTC, datetime
from typing import Any
from cryptography.fernet import Fernet
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.channel_connections.model import (
ChannelConnectionRow,
ChannelConversationRow,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
from deerflow.utils.time import coerce_iso
class ChannelCredentialCipher:
"""Encrypts provider credentials before they are persisted."""
def __init__(self, fernet: Fernet) -> None:
self._fernet = fernet
@classmethod
def from_key(cls, key: str) -> ChannelCredentialCipher:
digest = hashlib.sha256(key.encode("utf-8")).digest()
return cls(Fernet(base64.urlsafe_b64encode(digest)))
def encrypt_text(self, value: str | None) -> str | None:
if value is None:
return None
return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")
def decrypt_text(self, value: str | None) -> str | None:
if value is None:
return None
token = value.removeprefix("fernet:v1:")
return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")
class ChannelConnectionRepository:
"""Persistence facade for channel connections, credentials, and conversations."""
def __init__(
self,
session_factory: async_sessionmaker[AsyncSession],
*,
cipher: ChannelCredentialCipher | None = None,
) -> None:
self.session_factory = session_factory
self._cipher = cipher
async def close(self) -> None:
from deerflow.persistence.engine import close_engine
await close_engine()
@staticmethod
def _new_id() -> str:
return uuid.uuid4().hex
@staticmethod
def _normalize_optional_identity(value: str | None) -> str:
return value or ""
@staticmethod
def _coerce_datetime(value: datetime | None) -> datetime | None:
if value is None or value.tzinfo is not None:
return value
return value.replace(tzinfo=UTC)
def _encrypt_optional_secret(self, value: str | None) -> str | None:
if value is None:
return None
if self._cipher is None:
raise RuntimeError("channel connection encryption key is required")
return self._cipher.encrypt_text(value)
@staticmethod
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
data = row.to_dict()
data["external_account_id"] = data["external_account_id"] or None
data["workspace_id"] = data["workspace_id"] or None
data["scopes"] = data.pop("scopes_json") or []
data["capabilities"] = data.pop("capabilities_json") or {}
data["metadata"] = data.pop("metadata_json") or {}
for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
value = data.get(key)
if isinstance(value, datetime):
data[key] = coerce_iso(value)
return data
async def upsert_connection(
self,
*,
owner_user_id: str,
provider: str,
external_account_id: str | None = None,
external_account_name: str | None = None,
workspace_id: str | None = None,
workspace_name: str | None = None,
bot_user_id: str | None = None,
scopes: list[str] | None = None,
capabilities: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
status: str = "connected",
) -> dict[str, Any]:
external_account_id_value = self._normalize_optional_identity(external_account_id)
workspace_id_value = self._normalize_optional_identity(workspace_id)
async with self.session_factory() as session:
stmt = select(ChannelConnectionRow).where(
ChannelConnectionRow.owner_user_id == owner_user_id,
ChannelConnectionRow.provider == provider,
ChannelConnectionRow.external_account_id == external_account_id_value,
ChannelConnectionRow.workspace_id == workspace_id_value,
)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
row = ChannelConnectionRow(
id=self._new_id(),
owner_user_id=owner_user_id,
provider=provider,
external_account_id=external_account_id_value,
workspace_id=workspace_id_value,
)
session.add(row)
row.status = status
row.external_account_name = external_account_name
row.workspace_name = workspace_name
row.bot_user_id = bot_user_id
row.scopes_json = list(scopes or [])
row.capabilities_json = dict(capabilities or {})
row.metadata_json = dict(metadata or {})
await session.commit()
await session.refresh(row)
return self._connection_to_dict(row)
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
async with self.session_factory() as session:
result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
return [self._connection_to_dict(row) for row in result.scalars()]
async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
async with self.session_factory() as session:
row = await session.get(ChannelConnectionRow, connection_id)
if row is None or row.owner_user_id != owner_user_id:
return False
row.status = "revoked"
credential = await session.get(ChannelCredentialRow, connection_id)
if credential is not None:
await session.delete(credential)
await session.commit()
return True
async def store_credentials(
self,
connection_id: str,
*,
access_token: str | None,
refresh_token: str | None = None,
token_type: str | None = None,
expires_at: datetime | None = None,
refresh_expires_at: datetime | None = None,
extra: dict[str, Any] | None = None,
) -> None:
if self._cipher is None:
raise RuntimeError("channel connection encryption key is required")
async with self.session_factory() as session:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
row = ChannelCredentialRow(connection_id=connection_id)
session.add(row)
row.encrypted_access_token = self._cipher.encrypt_text(access_token)
row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
row.token_type = token_type
row.expires_at = expires_at
row.refresh_expires_at = refresh_expires_at
row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
row.version = (row.version or 0) + 1
await session.commit()
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
if self._cipher is None:
return None
async with self.session_factory() as session:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
return None
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
return {
"connection_id": row.connection_id,
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
"token_type": row.token_type,
"expires_at": self._coerce_datetime(row.expires_at),
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
"extra": json.loads(extra_raw) if extra_raw else {},
}
@staticmethod
def hash_state(state: str) -> str:
return hashlib.sha256(state.encode("utf-8")).hexdigest()
async def create_oauth_state(
self,
*,
owner_user_id: str,
provider: str,
state: str,
expires_at: datetime,
code_verifier: str | None = None,
nonce_hash: str | None = None,
redirect_after: str | None = None,
requested_scopes: list[str] | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
row = ChannelOAuthStateRow(
state_hash=self.hash_state(state),
owner_user_id=owner_user_id,
provider=provider,
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
nonce_hash=nonce_hash,
redirect_after=redirect_after,
requested_scopes_json=list(requested_scopes or []),
metadata_json=dict(metadata or {}),
expires_at=expires_at,
)
async with self.session_factory() as session:
session.add(row)
await session.commit()
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
async with self.session_factory() as session:
result = await session.execute(
select(ChannelOAuthStateRow).where(
ChannelOAuthStateRow.owner_user_id == owner_user_id,
ChannelOAuthStateRow.provider == provider,
)
)
return len(list(result.scalars()))
async def consume_oauth_state(
self,
*,
provider: str,
state: str,
now: datetime | None = None,
) -> dict[str, Any] | None:
current_time = now or datetime.now(UTC)
async with self.session_factory() as session:
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
if row is None or row.provider != provider or row.consumed_at is not None:
await session.commit()
return None
expires_at = self._coerce_datetime(row.expires_at)
if expires_at is not None and expires_at < current_time:
await session.commit()
return None
row.consumed_at = current_time
await session.commit()
return {
"owner_user_id": row.owner_user_id,
"provider": row.provider,
"requested_scopes": row.requested_scopes_json or [],
"metadata": row.metadata_json or {},
"redirect_after": row.redirect_after,
}
async def find_connection_by_external_identity(
self,
*,
provider: str,
external_account_id: str,
workspace_id: str | None = None,
) -> dict[str, Any] | None:
async with self.session_factory() as session:
result = await session.execute(
select(ChannelConnectionRow)
.where(
ChannelConnectionRow.provider == provider,
ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
ChannelConnectionRow.status == "connected",
)
.order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
.limit(1)
)
row = result.scalar_one_or_none()
return self._connection_to_dict(row) if row is not None else None
async def set_thread_id(
self,
*,
connection_id: str,
owner_user_id: str,
provider: str,
external_conversation_id: str,
thread_id: str,
external_topic_id: str | None = None,
) -> None:
topic_id = external_topic_id or ""
async with self.session_factory() as session:
stmt = select(ChannelConversationRow).where(
ChannelConversationRow.connection_id == connection_id,
ChannelConversationRow.external_conversation_id == external_conversation_id,
ChannelConversationRow.external_topic_id == topic_id,
)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
row = ChannelConversationRow(
id=self._new_id(),
connection_id=connection_id,
owner_user_id=owner_user_id,
provider=provider,
external_conversation_id=external_conversation_id,
external_topic_id=topic_id,
thread_id=thread_id,
)
session.add(row)
else:
row.thread_id = thread_id
row.owner_user_id = owner_user_id
row.provider = provider
await session.commit()
async def get_thread_id(
self,
connection_id: str,
external_conversation_id: str,
external_topic_id: str | None = None,
) -> str | None:
async with self.session_factory() as session:
stmt = select(ChannelConversationRow.thread_id).where(
ChannelConversationRow.connection_id == connection_id,
ChannelConversationRow.external_conversation_id == external_conversation_id,
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
)
return (await session.execute(stmt)).scalar_one_or_none()
@@ -14,10 +14,26 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and
there is no matching entity directory.
"""
from deerflow.persistence.channel_connections.model import (
ChannelConnectionRow,
ChannelConversationRow,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.run.model import RunRow
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.user.model import UserRow
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
__all__ = [
"ChannelConnectionRow",
"ChannelConversationRow",
"ChannelCredentialRow",
"ChannelOAuthStateRow",
"FeedbackRow",
"RunEventRow",
"RunRow",
"ThreadMetaRow",
"UserRow",
]
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
"""
pass
@abc.abstractmethod
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
"""Move a thread metadata row to a new owner.
Intended for trusted internal repair/migration paths. No-op if the
row does not exist or the caller fails the owner check.
"""
pass
@abc.abstractmethod
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``user_id`` has access to ``thread_id``."""
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
if record is None:
return
record["user_id"] = owner_user_id
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
if record is None:
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
row.updated_at = datetime.now(UTC)
await session.commit()
async def update_owner(
self,
thread_id: str,
owner_user_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Move a thread metadata row to ``owner_user_id``."""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
await session.commit()
async def delete(
self,
thread_id: str,
@@ -164,7 +164,18 @@ class RunJournal(BaseCallbackHandler):
metadata={"caller": caller, **(metadata or {})},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
def on_chain_end(
self,
outputs: Any,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
) -> None:
# Nested chain ends fire for internal graph nodes; only the root chain
# represents the user-visible run lifecycle.
if parent_run_id is not None:
return
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
self._flush_sync()
@@ -83,6 +83,7 @@ class RunRecord:
multitask_strategy: str = "reject"
metadata: dict = field(default_factory=dict)
kwargs: dict = field(default_factory=dict)
user_id: str | None = None
created_at: str = ""
updated_at: str = ""
task: asyncio.Task | None = field(default=None, repr=False)
@@ -124,7 +125,7 @@ class RunManager:
@staticmethod
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
return {
payload = {
"thread_id": record.thread_id,
"assistant_id": record.assistant_id,
"status": record.status.value,
@@ -135,6 +136,9 @@ class RunManager:
"created_at": record.created_at,
"model_name": record.model_name,
}
if record.user_id is not None:
payload["user_id"] = record.user_id
return payload
async def _call_store_with_retry(
self,
@@ -241,6 +245,7 @@ class RunManager:
kwargs=row.get("kwargs") or {},
created_at=row.get("created_at") or "",
updated_at=row.get("updated_at") or "",
user_id=row.get("user_id"),
error=row.get("error"),
model_name=row.get("model_name"),
store_only=True,
@@ -320,6 +325,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
user_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
@@ -333,6 +339,7 @@ class RunManager:
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
user_id=user_id,
created_at=now,
updated_at=now,
)
@@ -504,6 +511,7 @@ class RunManager:
kwargs: dict | None = None,
multitask_strategy: str = "reject",
model_name: str | None = None,
user_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -546,6 +554,7 @@ class RunManager:
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
user_id=user_id,
created_at=now,
updated_at=now,
model_name=model_name,
@@ -147,7 +147,17 @@ class LocalSandboxProvider(SandboxProvider):
mount.container_path,
)
continue
# Ensure the host path exists before adding mapping
# Ensure the host path exists before adding mapping.
#
# ``host_path`` is resolved against the filesystem of the
# process running this provider — for ``make dev`` that is
# the host machine, but for ``make up`` it is the
# ``deer-flow-gateway`` container, so any host path that
# isn't bind-mounted into the gateway image will be missing
# here. Skipping silently makes this a high-cost-to-debug
# silent failure (sandbox skill / tool reads an empty dir
# instead of the configured mount), so escalate to ERROR
# and include actionable guidance. See #3244.
if host_path.exists():
mappings.append(
PathMapping(
@@ -157,10 +167,16 @@ class LocalSandboxProvider(SandboxProvider):
)
)
else:
logger.warning(
"Mount host_path does not exist, skipping: %s -> %s",
logger.error(
"sandbox.mounts entry %s -> %s ignored: host_path %s does not exist from the "
"perspective of the gateway process. In Docker deployments (make up / docker-compose), "
"this path must also be bind-mounted into the gateway container — add a matching "
"volume entry under services.gateway.volumes in docker/docker-compose.yaml (and use "
"the in-container path here), or run in local mode (make dev) where the gateway sees "
"the host filesystem directly.",
mount.host_path,
mount.container_path,
mount.host_path,
)
except Exception as e:
# Log but don't fail if config loading fails
@@ -1,10 +1,15 @@
import asyncio
import logging
from collections.abc import Awaitable, Callable
from dataclasses import replace as dc_replace
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.runtime import Runtime
from langgraph.types import Command
from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.sandbox import get_sandbox_provider
@@ -126,3 +131,87 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
# No sandbox to release
return await super().aafter_agent(state, runtime)
# ------------------------------------------------------------------
# Tool-call wrappers: persist lazily-acquired sandbox state into the
# graph state via Command(update=...).
#
# Background:
# ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates
# ``runtime.state["sandbox"]`` directly. That mutation is local to the
# current tool invocation and is NOT picked up by LangGraph's channel
# reducer, so subsequent graph steps (and downstream consumers such as
# ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``)
# cannot observe the sandbox id. Wrapping the tool call lets us detect
# a fresh lazy init by diffing the state snapshot before/after the
# handler and emit a proper state update via ``Command``.
# ------------------------------------------------------------------
@staticmethod
def _read_sandbox_id_from_state(state: object) -> str | None:
if not isinstance(state, dict):
return None
sandbox_state = state.get("sandbox")
if not isinstance(sandbox_state, dict):
return None
sandbox_id = sandbox_state.get("sandbox_id")
return sandbox_id if isinstance(sandbox_id, str) else None
@staticmethod
def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command:
"""Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted.
- ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})``
- ``Command`` with dict update -> merge ``sandbox`` key, preserve all
existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...).
- ``Command`` with non-dict / None update -> leave it untouched to
avoid silent data loss on unknown update shapes.
"""
sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}}
if isinstance(result, ToolMessage):
return Command(update={**sandbox_update, "messages": [result]})
existing_update = result.update
if isinstance(existing_update, dict):
merged_update = {**existing_update, **sandbox_update}
return dc_replace(result, update=merged_update)
return result
@staticmethod
def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None:
"""Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes)."""
runtime = request.runtime
if runtime is None or runtime.state is None:
return None
return SandboxMiddleware._read_sandbox_id_from_state(runtime.state)
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
prev_sandbox_id = self._read_sandbox_id_from_request(request)
result = handler(request)
if prev_sandbox_id is not None:
return result
curr_sandbox_id = self._read_sandbox_id_from_request(request)
if curr_sandbox_id is None:
return result
return self._attach_sandbox_update(result, curr_sandbox_id)
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
prev_sandbox_id = self._read_sandbox_id_from_request(request)
result = await handler(request)
if prev_sandbox_id is not None:
return result
curr_sandbox_id = self._read_sandbox_id_from_request(request)
if curr_sandbox_id is None:
return result
return self._attach_sandbox_update(result, curr_sandbox_id)
+1
View File
@@ -36,6 +36,7 @@ dependencies = [
"sqlalchemy[asyncio]>=2.0,<3.0",
"aiosqlite>=0.19",
"alembic>=1.13",
"cryptography>=43.0.0",
]
[project.optional-dependencies]
@@ -69,6 +69,7 @@
"keys": [
"artifacts",
"messages",
"sandbox",
"thread_data",
"title",
"viewed_images"
@@ -79,6 +80,7 @@
"keys": [
"artifacts",
"messages",
"sandbox",
"thread_data",
"title",
"viewed_images"
@@ -89,6 +91,7 @@
"keys": [
"artifacts",
"messages",
"sandbox",
"thread_data",
"title",
"viewed_images"
@@ -99,6 +102,7 @@
"keys": [
"artifacts",
"messages",
"sandbox",
"thread_data",
"title",
"viewed_images"
@@ -109,6 +113,7 @@
"keys": [
"artifacts",
"messages",
"sandbox",
"thread_data",
"title",
"viewed_images"
@@ -119,6 +124,7 @@
"keys": [
"artifacts",
"messages",
"sandbox",
"thread_data",
"title",
"viewed_images"
@@ -0,0 +1,251 @@
"""Connection binding tests for browser-connectable IM channels beyond Telegram/Slack/Discord."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from app.channels.message_bus import InboundMessage, MessageBus
async def _make_repo(tmp_path, name: str):
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / f'{name}.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(get_session_factory())
async def _seed_state(repo, provider: str, state: str, owner_user_id: str = "deerflow-user-1") -> None:
await repo.create_oauth_state(
owner_user_id=owner_user_id,
provider=provider,
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
def test_feishu_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.feishu import FeishuChannel
async def go():
repo = await _make_repo(tmp_path, "feishu")
state = "feishu-bind-code"
await _seed_state(repo, "feishu", state)
channel = FeishuChannel(
bus=MessageBus(),
config={"app_id": "app", "app_secret": "secret", "connection_repo": repo},
)
channel._reply_card = AsyncMock()
handled = await channel._bind_connection_from_connect_code(
message_id="om-message-1",
chat_id="oc-chat-1",
user_id="ou-user-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "feishu"
assert connections[0]["external_account_id"] == "ou-user-1"
assert connections[0]["workspace_id"] == "oc-chat-1"
channel._reply_card.assert_awaited_once_with("om-message-1", "Feishu connected to DeerFlow.")
await repo.close()
anyio.run(go)
def test_dingtalk_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
async def go():
repo = await _make_repo(tmp_path, "dingtalk")
state = "dingtalk-bind-code"
await _seed_state(repo, "dingtalk", state)
channel = DingTalkChannel(
bus=MessageBus(),
config={"client_id": "client", "client_secret": "secret", "connection_repo": repo},
)
channel._send_connection_reply = AsyncMock()
handled = await channel._bind_connection_from_connect_code(
conversation_type=_CONVERSATION_TYPE_GROUP,
sender_staff_id="staff-user-1",
sender_nick="Alice",
conversation_id="cid-group-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "dingtalk"
assert connections[0]["external_account_id"] == "staff-user-1"
assert connections[0]["external_account_name"] == "Alice"
assert connections[0]["workspace_id"] == "cid-group-1"
channel._send_connection_reply.assert_awaited_once()
await repo.close()
anyio.run(go)
def test_wechat_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.wechat import WechatChannel
async def go():
repo = await _make_repo(tmp_path, "wechat")
state = "wechat-bind-code"
await _seed_state(repo, "wechat", state)
channel = WechatChannel(
bus=MessageBus(),
config={"bot_token": "token", "connection_repo": repo},
)
channel._send_connection_reply = AsyncMock()
handled = await channel._bind_connection_from_connect_code(
chat_id="wx-user-1",
context_token="ctx-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "wechat"
assert connections[0]["external_account_id"] == "wx-user-1"
assert connections[0]["workspace_id"] == "wx-user-1"
channel._send_connection_reply.assert_awaited_once_with("wx-user-1", "ctx-1", "WeChat connected to DeerFlow.")
await repo.close()
anyio.run(go)
def test_wecom_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.wecom import WeComChannel
async def go():
repo = await _make_repo(tmp_path, "wecom")
state = "wecom-bind-code"
await _seed_state(repo, "wecom", state)
channel = WeComChannel(
bus=MessageBus(),
config={"bot_id": "bot", "bot_secret": "secret", "connection_repo": repo},
)
channel._ws_client = MagicMock()
channel._ws_client.reply = AsyncMock()
frame = {"body": {"aibotid": "bot-1", "chattype": "single"}}
handled = await channel._bind_connection_from_connect_code(
frame=frame,
user_id="wecom-user-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "wecom"
assert connections[0]["external_account_id"] == "wecom-user-1"
assert connections[0]["workspace_id"] == "bot-1"
channel._ws_client.reply.assert_awaited_once_with(frame, {"msgtype": "text", "text": {"content": "WeCom connected to DeerFlow."}})
await repo.close()
anyio.run(go)
def test_additional_channels_attach_owner_identity(tmp_path):
import anyio
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
from app.channels.feishu import FeishuChannel
from app.channels.wechat import WechatChannel
from app.channels.wecom import WeComChannel
async def go():
repo = await _make_repo(tmp_path, "additional-identity")
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="feishu",
external_account_id="ou-user-1",
workspace_id="oc-chat-1",
)
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="dingtalk",
external_account_id="staff-user-1",
workspace_id="cid-group-1",
)
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="wechat",
external_account_id="wx-user-1",
workspace_id="wx-user-1",
)
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="wecom",
external_account_id="wecom-user-1",
workspace_id="bot-1",
)
cases = [
(
FeishuChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(channel_name="feishu", chat_id="oc-chat-1", user_id="ou-user-1", text="hello"),
),
(
DingTalkChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(
channel_name="dingtalk",
chat_id="cid-group-1",
user_id="staff-user-1",
text="hello",
metadata={
"conversation_type": _CONVERSATION_TYPE_GROUP,
"conversation_id": "cid-group-1",
},
),
),
(
WechatChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(channel_name="wechat", chat_id="wx-user-1", user_id="wx-user-1", text="hello"),
),
(
WeComChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(
channel_name="wecom",
chat_id="wecom-user-1",
user_id="wecom-user-1",
text="hello",
metadata={"aibotid": "bot-1"},
),
),
]
for channel, inbound in cases:
attached = await channel._attach_connection_identity(inbound)
assert attached.owner_user_id == "deerflow-user-1"
assert attached.connection_id
assert (
attached.workspace_id
== {
"feishu": "oc-chat-1",
"dingtalk": "cid-group-1",
"wechat": "wx-user-1",
"wecom": "bot-1",
}[channel.name]
)
await repo.close()
anyio.run(go)
@@ -1,7 +1,10 @@
import logging
import os
import subprocess
from types import SimpleNamespace
import pytest
from deerflow.community.aio_sandbox.local_backend import (
LocalContainerBackend,
_format_container_command_for_log,
@@ -234,3 +237,99 @@ def test_start_container_keeps_apple_container_port_format(monkeypatch):
captured_cmd = _capture_start_container_command(monkeypatch, backend, runtime="container")
assert captured_cmd[captured_cmd.index("-p") + 1] == "18080:8080"
def _backend_for_inspect_tests() -> LocalContainerBackend:
backend = LocalContainerBackend(
image="sandbox:latest",
base_port=8080,
container_prefix="sandbox",
config_mounts=[],
environment={},
)
backend._runtime = "docker"
return backend
def test_is_container_running_false_when_container_missing(monkeypatch):
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
return SimpleNamespace(stdout="", stderr="Error: No such object: sandbox-missing", returncode=1)
monkeypatch.setattr("subprocess.run", fake_run)
assert backend._is_container_running("sandbox-missing") is False
def test_is_container_running_raises_on_runtime_error(monkeypatch):
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1)
monkeypatch.setattr("subprocess.run", fake_run)
with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"):
backend._is_container_running("sandbox-busy")
def test_is_container_running_raises_on_timeout(monkeypatch):
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"])
monkeypatch.setattr("subprocess.run", fake_run)
with pytest.raises(RuntimeError, match="Timed out checking container sandbox-timeout"):
backend._is_container_running("sandbox-timeout")
def test_discover_returns_none_when_runtime_check_fails(monkeypatch):
"""A transient daemon error during discovery must fall through to create, not fail acquire."""
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1)
monkeypatch.setattr("subprocess.run", fake_run)
assert backend.discover("sandbox-blip") is None
def test_discover_returns_none_when_runtime_check_times_out(monkeypatch):
"""An inspect timeout during discovery must not propagate out of discover()."""
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"])
monkeypatch.setattr("subprocess.run", fake_run)
assert backend.discover("sandbox-timeout") is None
def test_is_container_running_false_on_apple_container_not_found(monkeypatch):
"""Apple Container's generic "not found" is trusted when it names the container."""
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
return SimpleNamespace(stdout="", stderr='Error: not found: "sandbox-apple"', returncode=1)
monkeypatch.setattr("subprocess.run", fake_run)
assert backend._is_container_running("sandbox-apple") is False
def test_is_container_running_raises_on_unrelated_not_found_error(monkeypatch):
"""Transient errors whose text contains "not found" must not be misread as a dead container."""
backend = _backend_for_inspect_tests()
def fake_run(cmd, **kwargs):
return SimpleNamespace(stdout="", stderr="Error: credential helper not found in $PATH", returncode=1)
monkeypatch.setattr("subprocess.run", fake_run)
with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"):
backend._is_container_running("sandbox-busy")
+152
View File
@@ -317,6 +317,28 @@ async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path,
pytest.fail("provider thread lock was not released after successor acquire_async")
@pytest.mark.anyio
async def test_acquire_internal_async_offloads_cached_reuse_health_check(tmp_path, monkeypatch):
"""Async cached reuse must keep backend health checks off the event loop."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider, _sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-cached-async")
provider._thread_sandboxes = {"thread-cached-async": "sandbox-cached-async"}
provider._backend.is_alive = MagicMock(return_value=True)
to_thread_calls: list[tuple[object, tuple[object, ...]]] = []
async def fake_to_thread(func, /, *args, **kwargs):
to_thread_calls.append((func, args))
return func(*args, **kwargs)
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread)
sandbox_id = await provider._acquire_internal_async("thread-cached-async")
assert sandbox_id == "sandbox-cached-async"
assert to_thread_calls == [(provider._reuse_in_process_sandbox, ("thread-cached-async",))]
def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
"""Provisioner mode must receive user_id so PVC subPath matches user isolation."""
remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend")
@@ -424,6 +446,136 @@ def test_release_swallows_close_errors(tmp_path, caplog):
assert "sandbox-rel-err" in provider._warm_pool
def test_get_uses_in_memory_registry_only(tmp_path):
"""get() must stay event-loop safe by avoiding backend health checks."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead")
provider._backend.is_alive = MagicMock(side_effect=AssertionError("get must not call backend health checks"))
assert provider.get("sandbox-dead") is sandbox
def test_acquire_drops_dead_cached_sandbox(tmp_path, monkeypatch):
"""acquire() must replace a stale active cache entry after its container dies."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead")
provider._thread_locks = {}
provider._thread_sandboxes = {"thread-dead": "sandbox-dead"}
provider._config = {"replicas": 3}
provider._backend.is_alive = MagicMock(return_value=False)
provider._backend.discover = MagicMock(return_value=None)
provider._backend.create = MagicMock(
return_value=aio_mod.SandboxInfo(
sandbox_id="sandbox-dead",
sandbox_url="http://fresh-sandbox",
container_name="deer-flow-sandbox-sandbox-dead",
)
)
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-dead")
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: [])
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True)
sandbox_id = provider.acquire("thread-dead")
assert sandbox_id == "sandbox-dead"
sandbox.close.assert_called_once_with()
provider._backend.destroy.assert_called_once()
provider._backend.create.assert_called_once()
assert provider._thread_sandboxes["thread-dead"] == "sandbox-dead"
assert provider._sandboxes["sandbox-dead"].base_url == "http://fresh-sandbox"
def test_acquire_keeps_cached_sandbox_when_health_check_errors(tmp_path):
"""Transient backend health-check errors must not destroy a tracked sandbox."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-transient")
provider._thread_locks = {}
provider._thread_sandboxes = {"thread-transient": "sandbox-transient"}
provider._backend.is_alive = MagicMock(side_effect=OSError("docker daemon busy"))
sandbox_id = provider.acquire("thread-transient")
assert sandbox_id == "sandbox-transient"
sandbox.close.assert_not_called()
provider._backend.destroy.assert_not_called()
assert provider._sandboxes["sandbox-transient"] is sandbox
def test_drop_unhealthy_sandbox_skips_recreated_entry(tmp_path):
"""A stale health-check result must not delete a newly registered sandbox."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._lock = aio_mod.threading.Lock()
provider._warm_pool = {}
provider._last_activity = {"sandbox-toctou": 1.0}
provider._thread_sandboxes = {"thread-toctou": "sandbox-toctou"}
old_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://old-sandbox")
new_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://new-sandbox")
new_sandbox = MagicMock()
provider._sandbox_infos = {"sandbox-toctou": new_info}
provider._sandboxes = {"sandbox-toctou": new_sandbox}
provider._backend = SimpleNamespace(destroy=MagicMock())
provider._drop_unhealthy_sandbox("sandbox-toctou", "stale health check", expected_info=old_info)
new_sandbox.close.assert_not_called()
provider._backend.destroy.assert_not_called()
assert provider._sandbox_infos["sandbox-toctou"] is new_info
assert provider._sandboxes["sandbox-toctou"] is new_sandbox
assert provider._thread_sandboxes == {"thread-toctou": "sandbox-toctou"}
def test_acquire_skips_dead_warm_pool_sandbox(tmp_path, monkeypatch):
"""acquire() must create a fresh sandbox when the warm-pool entry died."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._lock = aio_mod.threading.Lock()
provider._thread_locks = {}
provider._sandboxes = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._warm_pool = {
"sandbox-warm-dead": (
aio_mod.SandboxInfo(
sandbox_id="sandbox-warm-dead",
sandbox_url="http://stale-sandbox",
container_name="deer-flow-sandbox-sandbox-warm-dead",
),
0.0,
)
}
provider._config = {"replicas": 3}
provider._backend = SimpleNamespace(
is_alive=MagicMock(return_value=False),
destroy=MagicMock(),
discover=MagicMock(return_value=None),
create=MagicMock(
return_value=aio_mod.SandboxInfo(
sandbox_id="sandbox-warm-dead",
sandbox_url="http://fresh-sandbox",
container_name="deer-flow-sandbox-sandbox-warm-dead",
)
),
)
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-warm-dead")
monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: [])
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None)
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True)
sandbox_id = provider.acquire("thread-warm-dead")
assert sandbox_id == "sandbox-warm-dead"
provider._backend.destroy.assert_called_once()
provider._backend.create.assert_called_once()
assert provider._warm_pool == {}
assert provider._thread_sandboxes["thread-warm-dead"] == "sandbox-warm-dead"
assert provider._sandboxes["sandbox-warm-dead"].base_url == "http://fresh-sandbox"
def test_destroy_swallows_close_errors_and_still_destroys_backend(tmp_path, caplog):
"""A failure in sandbox.close() must not skip backend container destruction."""
provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dest-err")
+191 -5
View File
@@ -4,6 +4,7 @@ import pytest
from starlette.testclient import TestClient
from app.gateway.auth_middleware import AuthMiddleware, _is_public
from app.gateway.csrf_middleware import CSRFMiddleware
# ── _is_public unit tests ─────────────────────────────────────────────────
@@ -38,6 +39,8 @@ def test_public_paths(path: str):
"/api/threads/123/uploads",
"/api/agents",
"/api/channels",
"/api/channels/providers",
"/api/channels/slack/connect",
"/api/runs/stream",
"/api/threads/123/runs",
"/api/v1/auth/me",
@@ -88,7 +91,9 @@ def test_unknown_api_path_is_protected():
def _make_app():
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
from fastapi import FastAPI
from fastapi import FastAPI, Request
from deerflow.runtime.user_context import get_effective_user_id
app = FastAPI()
app.add_middleware(AuthMiddleware)
@@ -98,8 +103,16 @@ def _make_app():
return {"status": "ok"}
@app.get("/api/v1/auth/me")
async def auth_me():
return {"id": "1", "email": "test@test.com"}
async def auth_me(request: Request):
from app.gateway.deps import get_current_user_from_request
user = await get_current_user_from_request(request)
return {
"id": str(user.id),
"email": user.email,
"system_role": user.system_role,
"needs_setup": user.needs_setup,
}
@app.get("/api/v1/auth/setup-status")
async def setup_status():
@@ -109,6 +122,29 @@ def _make_app():
async def models_get():
return {"models": []}
@app.get("/api/whoami")
async def whoami(request: Request):
user = request.state.user
return {
"id": str(user.id),
"email": getattr(user, "email", None),
"system_role": getattr(user, "system_role", None),
"context_user_id": get_effective_user_id(),
}
@app.get("/api/current-user-from-dep")
async def current_user_from_dep(request: Request):
from app.gateway.deps import get_current_user_from_request
user = await get_current_user_from_request(request)
state_user = request.state.user
return {
"id": str(user.id),
"state_id": str(state_user.id),
"auth_source": request.state.auth_source,
"context_user_id": get_effective_user_id(),
}
@app.put("/api/mcp/config")
async def mcp_put():
return {"ok": True}
@@ -132,8 +168,24 @@ def _make_app():
return app
def _make_auth_csrf_app():
"""Create a minimal app with production middleware ordering."""
from fastapi import FastAPI
app = FastAPI()
app.add_middleware(AuthMiddleware)
app.add_middleware(CSRFMiddleware)
@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}
return app
@pytest.fixture
def client():
def client(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "")
return TestClient(_make_app())
@@ -161,11 +213,145 @@ def test_protected_path_no_cookie_returns_401(client):
assert body["detail"]["code"] == "not_authenticated"
def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get("/api/models")
assert res.status_code == 200
assert res.json() == {"models": []}
def test_auth_disabled_stamps_default_admin_user_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get("/api/whoami")
assert res.status_code == 200
assert res.json() == {
"id": "default",
"email": "default@test.local",
"system_role": "admin",
"context_user_id": "default",
}
def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get("/api/v1/auth/me")
assert res.status_code == 200
assert res.json() == {
"id": "default",
"email": "default@test.local",
"system_role": "admin",
"needs_setup": False,
}
def test_auth_disabled_does_not_clobber_valid_session_cookie(monkeypatch):
from types import SimpleNamespace
async def fake_current_user(request):
return SimpleNamespace(
id="session-user",
email="session@test.local",
system_role="user",
needs_setup=False,
)
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.setattr("app.gateway.deps.get_current_user_from_request", fake_current_user)
client = TestClient(_make_app())
res = client.get("/api/whoami", cookies={"access_token": "valid-session"})
assert res.status_code == 200
assert res.json() == {
"id": "session-user",
"email": "session@test.local",
"system_role": "user",
"context_user_id": "session-user",
}
def test_auth_disabled_does_not_clobber_internal_auth_identity(monkeypatch):
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import DEFAULT_USER_ID
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app())
res = client.get(
"/api/current-user-from-dep",
headers=create_internal_auth_headers(),
)
assert res.status_code == 200
assert res.json() == {
"id": DEFAULT_USER_ID,
"state_id": DEFAULT_USER_ID,
"auth_source": "internal",
"context_user_id": DEFAULT_USER_ID,
}
def test_auth_disabled_skips_csrf_for_state_changing_requests(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_auth_csrf_app())
res = client.post("/api/threads/abc/runs/stream")
assert res.status_code == 200
assert res.json() == {"ok": True}
def test_auth_disabled_is_ignored_in_explicit_production_env(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.setenv("DEER_FLOW_ENV", "production")
client = TestClient(_make_app())
res = client.get("/api/models")
assert res.status_code == 401
def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
monkeypatch.delenv("ENVIRONMENT", raising=False)
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
warn_if_auth_disabled_enabled()
assert "authentication is bypassed" in caplog.text
assert "default" in caplog.text
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.setenv("ENVIRONMENT", "production")
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
warn_if_auth_disabled_enabled()
assert "authentication is bypassed" not in caplog.text
def test_protected_path_with_junk_cookie_rejected(client):
"""Junk cookie → 401. Middleware strictly validates the JWT now
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
tokens through to the route handler."""
res = client.get("/api/models", cookies={"access_token": "some-token"})
client.cookies.set("access_token", "some-token")
res = client.get("/api/models")
assert res.status_code == 401
@@ -0,0 +1,56 @@
"""Tests for user-facing IM channel connection configuration."""
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
def test_channel_connections_disabled_by_default():
config = ChannelConnectionsConfig()
assert config.enabled is False
assert config.slack.enabled is False
assert config.telegram.enabled is False
assert config.discord.enabled is False
assert config.feishu.enabled is False
assert config.dingtalk.enabled is False
assert config.wechat.enabled is False
assert config.wecom.enabled is False
def test_enabled_channel_connections_do_not_require_public_url_or_encryption_key():
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {
"enabled": True,
"bot_username": "deerflow_bot",
},
"slack": {"enabled": True},
"discord": {"enabled": True},
"feishu": {"enabled": True},
"dingtalk": {"enabled": True},
"wechat": {"enabled": True},
"wecom": {"enabled": True},
}
)
assert config.enabled is True
assert config.provider_status("telegram") == {"enabled": True, "configured": True}
assert config.provider_status("slack") == {"enabled": True, "configured": True}
assert config.provider_status("discord") == {"enabled": True, "configured": True}
assert config.provider_status("feishu") == {"enabled": True, "configured": True}
assert config.provider_status("dingtalk") == {"enabled": True, "configured": True}
assert config.provider_status("wechat") == {"enabled": True, "configured": True}
assert config.provider_status("wecom") == {"enabled": True, "configured": True}
def test_provider_status_reports_disabled_and_unknown_providers():
config = ChannelConnectionsConfig.model_validate({"enabled": True})
assert config.provider_status("slack") == {"enabled": False, "configured": False}
assert config.provider_status("telegram") == {"enabled": False, "configured": False}
assert config.provider_status("discord") == {"enabled": False, "configured": False}
assert config.provider_status("feishu") == {"enabled": False, "configured": False}
assert config.provider_status("dingtalk") == {"enabled": False, "configured": False}
assert config.provider_status("wechat") == {"enabled": False, "configured": False}
assert config.provider_status("wecom") == {"enabled": False, "configured": False}
assert config.provider_status("unknown") == {"enabled": False, "configured": False}
@@ -0,0 +1,226 @@
"""Tests for per-user IM channel connection persistence."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import pytest
from sqlalchemy import select
from deerflow.persistence.channel_connections import (
ChannelConnectionRepository,
ChannelConnectionRow,
ChannelCredentialCipher,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
@pytest.fixture
async def repo(tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
try:
yield ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("test-encryption-key"),
)
finally:
await close_engine()
class TestChannelConnectionRepository:
@pytest.mark.anyio
async def test_connections_are_listed_per_owner(self, repo):
alice = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
external_account_name="Alice",
workspace_id="T1",
workspace_name="Team One",
scopes=["chat:write"],
)
await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-bob",
external_account_name="Bob",
workspace_id="T1",
workspace_name="Team One",
scopes=["chat:write"],
)
results = await repo.list_connections("alice")
assert [item["id"] for item in results] == [alice["id"]]
assert results[0]["owner_user_id"] == "alice"
assert results[0]["provider"] == "slack"
assert results[0]["scopes"] == ["chat:write"]
assert "encrypted_access_token" not in results[0]
@pytest.mark.anyio
async def test_upsert_connection_updates_existing_provider_identity(self, repo):
first = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
external_account_name="Alice",
workspace_id=None,
workspace_name=None,
status="pending",
)
second = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
external_account_name="Alice Telegram",
workspace_id=None,
workspace_name=None,
status="connected",
)
assert second["id"] == first["id"]
assert second["status"] == "connected"
assert second["external_account_name"] == "Alice Telegram"
assert len(await repo.list_connections("alice")) == 1
@pytest.mark.anyio
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
expires_at = datetime.now(UTC) + timedelta(hours=1)
await repo.store_credentials(
connection["id"],
access_token="xoxb-secret-access-token",
refresh_token="secret-refresh-token",
token_type="Bearer",
expires_at=expires_at,
extra={"bot_user_id": "B123"},
)
async with repo.session_factory() as session:
row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
assert row.encrypted_access_token is not None
assert "xoxb-secret-access-token" not in row.encrypted_access_token
assert "secret-refresh-token" not in (row.encrypted_refresh_token or "")
assert "B123" not in (row.encrypted_extra_json or "")
credentials = await repo.get_credentials(connection["id"])
assert credentials is not None
assert credentials["access_token"] == "xoxb-secret-access-token"
assert credentials["refresh_token"] == "secret-refresh-token"
assert credentials["token_type"] == "Bearer"
assert credentials["expires_at"] == expires_at
assert credentials["extra"] == {"bot_user_id": "B123"}
@pytest.mark.anyio
async def test_conversations_are_scoped_by_connection(self, repo):
alice = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
bob = await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-bob",
workspace_id="T1",
)
await repo.set_thread_id(
connection_id=alice["id"],
owner_user_id="alice",
provider="slack",
external_conversation_id="C-shared",
external_topic_id="1710000000.000100",
thread_id="thread-alice",
)
await repo.set_thread_id(
connection_id=bob["id"],
owner_user_id="bob",
provider="slack",
external_conversation_id="C-shared",
external_topic_id="1710000000.000100",
thread_id="thread-bob",
)
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
@pytest.mark.anyio
async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
)
await repo.store_credentials(connection["id"], access_token="secret-token")
disconnected = await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id="alice",
)
assert disconnected is True
async with repo.session_factory() as session:
connection_row = await session.get(ChannelConnectionRow, connection["id"])
credential_row = await session.get(ChannelCredentialRow, connection["id"])
assert connection_row is not None
assert connection_row.status == "revoked"
assert credential_row is None
assert (
await repo.find_connection_by_external_identity(
provider="telegram",
external_account_id="42",
)
is None
)
@pytest.mark.anyio
async def test_disconnect_connection_is_owner_scoped(self, repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
)
disconnected = await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id="bob",
)
assert disconnected is False
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
@pytest.mark.anyio
async def test_consume_oauth_state_deletes_expired_states(self, repo):
now = datetime.now(UTC)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="expired-state",
expires_at=now - timedelta(minutes=1),
)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="active-state",
expires_at=now + timedelta(minutes=5),
)
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
assert consumed is None
async with repo.session_factory() as session:
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
@@ -0,0 +1,852 @@
"""Router tests for browser-connectable IM channels."""
from __future__ import annotations
from tempfile import TemporaryDirectory
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import UUID
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.gateway.auth.models import User
from app.gateway.routers import channel_connections
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
@pytest.fixture(autouse=True)
def _stub_app_config():
"""Keep router tests independent from a developer-local config.yaml."""
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def _user() -> User:
return User(
id=UUID("11111111-2222-3333-4444-555555555555"),
email="alice@example.com",
password_hash="x",
system_role="user",
)
async def _make_repo(tmp_path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'router.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(get_session_factory())
def _make_app(
config: ChannelConnectionsConfig,
repo,
channels_config: dict | None = None,
*,
runtime_config_store: ChannelRuntimeConfigStore | None = None,
set_channels_config_state: bool = True,
):
app = make_authed_test_app(user_factory=_user)
app.state.channel_connections_config = config
app.state.channel_connection_repo = repo
if set_channels_config_state:
app.state.channels_config = channels_config or {}
if runtime_config_store is None:
runtime_config_dir = TemporaryDirectory()
app.state.channel_runtime_config_tmpdir = runtime_config_dir
runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json")
app.state.channel_runtime_config_store = runtime_config_store
app.include_router(channel_connections.router)
return app
def _enabled_connections_config() -> ChannelConnectionsConfig:
return ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
"slack": {"enabled": True},
"discord": {"enabled": True},
"feishu": {"enabled": True},
"dingtalk": {"enabled": True},
"wechat": {"enabled": True},
"wecom": {"enabled": True},
}
)
def _channels_config() -> dict:
return {
"telegram": {"enabled": True, "bot_token": "telegram-token"},
"slack": {"enabled": True, "bot_token": "xoxb-operator", "app_token": "xapp-operator"},
"discord": {"enabled": True, "bot_token": "discord-bot"},
"feishu": {"enabled": True, "app_id": "feishu-app", "app_secret": "feishu-secret"},
"dingtalk": {"enabled": True, "client_id": "dingtalk-client", "client_secret": "dingtalk-secret"},
"wechat": {"enabled": True, "bot_token": "wechat-token"},
"wecom": {"enabled": True, "bot_id": "wecom-bot", "bot_secret": "wecom-secret"},
}
def test_get_providers_only_returns_enabled_channels_and_setup_fields(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
"discord": {"enabled": False},
}
)
app = _make_app(config, repo, {})
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
body = response.json()
assert body["enabled"] is True
assert [provider["provider"] for provider in body["providers"]] == ["slack"]
assert body["providers"][0]["configured"] is False
assert body["providers"][0]["connectable"] is False
assert body["providers"][0]["credential_fields"] == [
{
"name": "bot_token",
"label": "Bot token",
"type": "password",
"required": True,
},
{
"name": "app_token",
"label": "App token",
"type": "password",
"required": True,
},
]
anyio.run(repo.close)
def test_get_providers_uses_existing_channels_config(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
body = response.json()
assert body["enabled"] is True
by_provider = {item["provider"]: item for item in body["providers"]}
assert set(by_provider) == {"telegram", "slack", "discord", "feishu", "dingtalk", "wechat", "wecom"}
assert by_provider["telegram"]["configured"] is True
assert by_provider["telegram"]["auth_mode"] == "deep_link"
assert by_provider["telegram"]["credential_values"] == {
"bot_token": "********",
"bot_username": "deerflow_bot",
}
assert by_provider["slack"]["configured"] is True
assert by_provider["slack"]["auth_mode"] == "binding_code"
assert by_provider["slack"]["connection_status"] == "connected"
assert by_provider["slack"]["credential_values"] == {
"bot_token": "********",
"app_token": "********",
}
assert by_provider["discord"]["configured"] is True
assert by_provider["discord"]["auth_mode"] == "binding_code"
assert by_provider["discord"]["credential_values"] == {"bot_token": "********"}
assert by_provider["feishu"]["configured"] is True
assert by_provider["feishu"]["auth_mode"] == "binding_code"
assert by_provider["feishu"]["connection_status"] == "connected"
assert by_provider["feishu"]["credential_values"] == {
"app_id": "feishu-app",
"app_secret": "********",
}
assert by_provider["dingtalk"]["configured"] is True
assert by_provider["dingtalk"]["auth_mode"] == "binding_code"
assert by_provider["dingtalk"]["credential_values"] == {
"client_id": "dingtalk-client",
"client_secret": "********",
}
assert by_provider["wechat"]["configured"] is True
assert by_provider["wechat"]["auth_mode"] == "binding_code"
assert by_provider["wechat"]["credential_values"] == {"bot_token": "********"}
assert by_provider["wecom"]["configured"] is True
assert by_provider["wecom"]["auth_mode"] == "binding_code"
assert by_provider["wecom"]["credential_values"] == {
"bot_id": "wecom-bot",
"bot_secret": "********",
}
anyio.run(repo.close)
def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, {"telegram": {"enabled": True, "bot_token": "telegram-token"}})
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["telegram"]["configured"] is True
assert by_provider["slack"]["configured"] is False
assert by_provider["slack"]["connectable"] is False
assert "Slack credentials" in by_provider["slack"]["unavailable_reason"]
assert by_provider["discord"]["configured"] is False
assert "Discord credentials" in by_provider["discord"]["unavailable_reason"]
assert by_provider["feishu"]["configured"] is False
assert "Feishu credentials" in by_provider["feishu"]["unavailable_reason"]
assert by_provider["dingtalk"]["configured"] is False
assert "DingTalk credentials" in by_provider["dingtalk"]["unavailable_reason"]
assert by_provider["wechat"]["configured"] is False
assert "WeChat credentials" in by_provider["wechat"]["unavailable_reason"]
assert by_provider["wecom"]["configured"] is False
assert "WeCom credentials" in by_provider["wecom"]["unavailable_reason"]
anyio.run(repo.close)
def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypatch):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
service = SimpleNamespace(
get_status=lambda: {
"service_running": True,
"channels": {
"feishu": {
"enabled": True,
"running": False,
}
},
}
)
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["feishu"]["configured"] is True
assert by_provider["feishu"]["connectable"] is False
assert by_provider["feishu"]["connection_status"] == "not_connected"
assert "configured but is not running" in by_provider["feishu"]["unavailable_reason"]
anyio.run(repo.close)
def test_get_providers_restarts_configured_channel_when_service_can_reconcile(tmp_path, monkeypatch):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
)
channels_config = {
"feishu": {
"enabled": True,
"app_id": "feishu-app",
"app_secret": "feishu-secret",
}
}
app = _make_app(config, repo, channels_config)
status = {
"service_running": True,
"channels": {
"feishu": {
"enabled": True,
"running": False,
}
},
}
reconciled: list[tuple[str, dict]] = []
async def ensure_channel_ready(provider, runtime_config):
reconciled.append((provider, dict(runtime_config)))
status["channels"][provider]["running"] = True
return True
service = SimpleNamespace(
get_status=lambda: status,
ensure_channel_ready=ensure_channel_ready,
)
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["feishu"]["configured"] is True
assert by_provider["feishu"]["connectable"] is True
assert by_provider["feishu"]["connection_status"] == "connected"
assert by_provider["feishu"]["unavailable_reason"] is None
assert reconciled == [("feishu", channels_config["feishu"])]
anyio.run(repo.close)
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connections():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U-old",
workspace_id="T-old",
status="revoked",
)
await anyio.sleep(0.01)
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U-new",
workspace_id="T-new",
status="connected",
)
anyio.run(seed_connections)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["slack"]["connection_status"] == "connected"
anyio.run(repo.close)
def test_get_connections_returns_current_user_connections_only(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connections():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="telegram",
external_account_id="42",
external_account_name="Alice",
status="connected",
)
await repo.upsert_connection(
owner_user_id="other-user",
provider="telegram",
external_account_id="99",
external_account_name="Bob",
status="connected",
)
anyio.run(seed_connections)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/connections")
assert response.status_code == 200
body = response.json()
assert len(body["connections"]) == 1
assert body["connections"][0]["provider"] == "telegram"
assert body["connections"][0]["external_account_id"] == "42"
anyio.run(repo.close)
def test_connect_telegram_returns_deep_link_and_persists_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.post("/api/channels/telegram/connect")
assert response.status_code == 200
body = response.json()
assert body["provider"] == "telegram"
assert body["mode"] == "deep_link"
assert body["url"].startswith("https://t.me/deerflow_bot?start=")
assert body["code"]
assert "/start" in body["instruction"]
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="telegram")
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.post("/api/channels/slack/connect")
assert response.status_code == 200
body = response.json()
assert body["provider"] == "slack"
assert body["mode"] == "binding_code"
assert body["url"] is None
assert len(body["code"]) >= 22
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.post("/api/channels/discord/connect")
assert response.status_code == 200
body = response.json()
assert body["provider"] == "discord"
assert body["mode"] == "binding_code"
assert body["url"] is None
assert body["code"]
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Discord bot."
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="discord")
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_existing_binding_code_channels_return_command_and_persist_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
providers = ["feishu", "dingtalk", "wechat", "wecom"]
with TestClient(app) as client:
responses = {provider: client.post(f"/api/channels/{provider}/connect") for provider in providers}
for provider, response in responses.items():
expected_display_name = {
"feishu": "Feishu",
"dingtalk": "DingTalk",
"wechat": "WeChat",
"wecom": "WeCom",
}[provider]
assert response.status_code == 200
body = response.json()
assert body["provider"] == provider
assert body["mode"] == "binding_code"
assert body["url"] is None
assert len(body["code"]) >= 22
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow {expected_display_name} bot."
async def count_states(provider=provider):
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider=provider)
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_unconfigured_runtime_channel_returns_400(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, {})
with TestClient(app) as client:
response = client.post("/api/channels/slack/connect")
assert response.status_code == 400
assert "Slack credentials" in response.json()["detail"]
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_enables_connect_without_file_edits(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
app = _make_app(config, repo, {})
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
connect_response = client.post("/api/channels/slack/connect")
assert configure_response.status_code == 200
configured = configure_response.json()
assert configured["provider"] == "slack"
assert configured["configured"] is True
assert configured["connectable"] is True
assert configured["connection_status"] == "connected"
assert app.state.channels_config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
assert connect_response.status_code == 200
assert connect_response.json()["provider"] == "slack"
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_path = tmp_path / "channels" / "runtime-config.json"
first_app = _make_app(
config,
repo,
{},
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
)
with TestClient(first_app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
assert configure_response.status_code == 200
restarted_app = _make_app(
config,
repo,
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
set_channels_config_state=False,
)
with TestClient(restarted_app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["slack"]["configured"] is True
assert by_provider["slack"]["connectable"] is True
assert by_provider["slack"]["connection_status"] == "connected"
assert restarted_app.state.channels_config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_preserves_masked_secrets(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(
config,
repo,
{
"feishu": {
"enabled": True,
"app_id": "old-app-id",
"app_secret": "old-secret",
}
},
runtime_config_store=runtime_config_store,
)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/feishu/runtime-config",
json={
"values": {
"app_id": "new-app-id",
"app_secret": "********",
}
},
)
providers_response = client.get("/api/channels/providers")
assert configure_response.status_code == 200
assert app.state.channels_config["feishu"] == {
"enabled": True,
"app_id": "new-app-id",
"app_secret": "old-secret",
}
assert runtime_config_store.get_provider_config("feishu") == {
"enabled": True,
"app_id": "new-app-id",
"app_secret": "old-secret",
}
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["feishu"]["credential_values"] == {
"app_id": "new-app-id",
"app_secret": "********",
}
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
providers_response = client.get("/api/channels/providers")
assert configure_response.status_code == 200
assert disconnect_response.status_code == 200
disconnected = disconnect_response.json()
assert disconnected["provider"] == "slack"
assert disconnected["configured"] is False
assert disconnected["connectable"] is False
assert disconnected["connection_status"] == "not_connected"
assert runtime_config_store.get_provider_config("slack") == {
"enabled": False,
"_runtime_disabled": True,
}
assert providers_response.status_code == 200
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["slack"]["connection_status"] == "not_connected"
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_suppresses_file_config_and_stops_channel(tmp_path, monkeypatch):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
)
set_app_config(
AppConfig.model_validate(
{
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"channels": {
"feishu": {
"enabled": True,
"app_id": "file-app-id",
"app_secret": "file-secret",
}
},
}
)
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
runtime_config_store.set_provider_config(
"feishu",
{
"enabled": True,
"app_id": "runtime-app-id",
"app_secret": "runtime-secret",
},
)
service = SimpleNamespace(
configure_channel=AsyncMock(return_value=True),
remove_channel=AsyncMock(return_value=True),
)
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
app = _make_app(
config,
repo,
{
"feishu": {
"enabled": True,
"app_id": "runtime-app-id",
"app_secret": "runtime-secret",
}
},
runtime_config_store=runtime_config_store,
)
with TestClient(app) as client:
disconnect_response = client.delete("/api/channels/feishu/runtime-config")
providers_response = client.get("/api/channels/providers")
assert disconnect_response.status_code == 200
disconnected = disconnect_response.json()
assert disconnected["provider"] == "feishu"
assert disconnected["configured"] is False
assert disconnected["connectable"] is False
assert disconnected["connection_status"] == "not_connected"
assert "feishu" not in app.state.channels_config
service.remove_channel.assert_awaited_once_with("feishu")
service.configure_channel.assert_not_awaited()
assert providers_response.status_code == 200
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["feishu"]["configured"] is False
assert by_provider["feishu"]["connection_status"] == "not_connected"
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U123",
status="connected",
)
anyio.run(seed_connection)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
assert configure_response.status_code == 200
assert disconnect_response.status_code == 200
async def get_connection_status():
return (await repo.list_connections(str(_user().id)))[0]["status"]
assert anyio.run(get_connection_status) == "revoked"
anyio.run(repo.close)
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
connection = await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="telegram",
external_account_id="42",
status="connected",
)
return connection["id"]
connection_id = anyio.run(seed_connection)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.delete(f"/api/channels/connections/{connection_id}")
assert response.status_code == 204
async def get_connection_status():
return (await repo.list_connections(str(_user().id)))[0]["status"]
assert anyio.run(get_connection_status) == "revoked"
anyio.run(repo.close)
def test_disconnect_connection_is_current_user_scoped(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
connection = await repo.upsert_connection(
owner_user_id="other-user",
provider="telegram",
external_account_id="42",
status="connected",
)
return connection["id"]
connection_id = anyio.run(seed_connection)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.delete(f"/api/channels/connections/{connection_id}")
assert response.status_code == 404
async def get_connection_status():
return (await repo.list_connections("other-user"))[0]["status"]
assert anyio.run(get_connection_status) == "connected"
anyio.run(repo.close)
+456 -10
View File
@@ -487,6 +487,7 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
# threads.create() returns a Thread-like dict
mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id})
mock_client.threads.update = AsyncMock(return_value={"thread_id": thread_id})
# threads.get() returns thread info (succeeds by default)
mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id})
@@ -504,6 +505,17 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
return mock_client
async def _make_channel_connection_repo(tmp_path: Path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'channel-connections.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("test-channel-key"),
)
def _make_stream_part(event: str, data):
return SimpleNamespace(event=event, data=data)
@@ -656,16 +668,34 @@ class TestChannelManager:
await manager.start()
inbound = InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi")
inbound = InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="hi",
topic_id="topic1",
thread_ts="msg1",
connection_id="conn1",
)
await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop()
# Thread should be created through Gateway
mock_client.threads.create.assert_called_once()
assert mock_client.threads.create.call_args.kwargs["metadata"] == {
"channel_source": {
"type": "im_channel",
"provider": "test",
"chat_id": "chat1",
"topic_id": "topic1",
"thread_ts": "msg1",
"connection_id": "conn1",
}
}
# Thread ID should be stored
thread_id = store.get_thread_id("test", "chat1")
thread_id = store.get_thread_id("test", "chat1", topic_id="topic1")
assert thread_id == "test-thread-123"
# runs.wait should be called with the thread_id
@@ -883,10 +913,12 @@ class TestChannelManager:
_run(go())
def test_clarification_follow_up_preserves_history(self):
def test_clarification_follow_up_preserves_history(self, monkeypatch):
"""Conversation should continue after ask_clarification instead of resetting history."""
from app.channels.manager import ChannelManager
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
@@ -1954,10 +1986,12 @@ class TestChannelManager:
_run(go())
def test_same_topic_reuses_thread(self):
def test_same_topic_reuses_thread(self, monkeypatch):
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
from app.channels.manager import ChannelManager
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
@@ -1990,6 +2024,17 @@ class TestChannelManager:
# threads.create should be called only ONCE (second message reuses the thread)
mock_client.threads.create.assert_called_once()
mock_client.threads.update.assert_called_once_with(
"topic-thread-1",
metadata={
"channel_source": {
"type": "im_channel",
"provider": "test",
"chat_id": "chat1",
"topic_id": "topic-root-123",
}
},
)
# Both runs.wait calls should use the same thread_id
assert mock_client.runs.wait.call_count == 2
@@ -2325,8 +2370,9 @@ class TestResolveRunParamsUserId:
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
return ChannelManager(bus=bus, store=store)
def test_safe_user_id_is_passed_through(self):
def test_safe_user_id_is_passed_through(self, monkeypatch):
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
@@ -2334,10 +2380,78 @@ class TestResolveRunParamsUserId:
assert run_context["user_id"] == "123456"
assert run_context["channel_user_id"] == "123456"
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch):
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(
channel_name="slack",
chat_id="C123",
user_id="U-platform",
owner_user_id="deerflow-user-1",
connection_id="connection-1",
text="hi",
)
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == "deerflow-user-1"
assert run_context["channel_user_id"] == "U-platform"
def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch):
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
manager = self._manager()
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
assert run_context["channel_user_id"] == "U-platform"
from app.channels.manager import _owner_headers
headers = _owner_headers(msg)
assert headers is not None
assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID
def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch):
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
manager = self._manager()
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
msg = InboundMessage(
channel_name="slack",
chat_id="C123",
user_id="U-platform",
owner_user_id="real-user-from-old-binding",
text="hi",
)
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
assert run_context["channel_user_id"] == "U-platform"
def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch):
from app.channels.manager import _owner_headers
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == "U-platform"
assert run_context["channel_user_id"] == "U-platform"
assert _owner_headers(msg) is None
def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch):
from deerflow.config.paths import make_safe_user_id
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
raw = "user@example.com"
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
@@ -2347,9 +2461,32 @@ class TestResolveRunParamsUserId:
assert run_context["user_id"] != raw
assert run_context["channel_user_id"] == raw
@pytest.mark.parametrize("raw_user_id", ["", None])
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
def test_unsafe_user_id_migrates_unique_legacy_bucket(self, tmp_path, monkeypatch):
from deerflow.config.paths import Paths, make_safe_user_id
paths = Paths(tmp_path)
legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b"
legacy_dir.mkdir(parents=True)
(legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8")
monkeypatch.setattr("deerflow.config.paths.get_paths", lambda: paths)
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
raw = "user@example.com"
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
safe = make_safe_user_id(raw)
assert run_context["user_id"] == safe
assert paths.user_dir(safe).exists()
assert not legacy_dir.exists()
assert (paths.user_dir(safe) / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n'
@pytest.mark.parametrize("raw_user_id", ["", None])
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch):
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
@@ -2358,6 +2495,93 @@ class TestResolveRunParamsUserId:
assert "channel_user_id" not in run_context
class TestChannelManagerConnectionRouting:
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch):
from app.channels.manager import ChannelManager
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
from deerflow.persistence.engine import close_engine
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go():
repo = await _make_channel_connection_repo(tmp_path)
alice = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
bob = await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-bob",
workspace_id="T1",
)
bus = MessageBus()
store = ChannelStore(path=tmp_path / "legacy-store.json")
manager = ChannelManager(bus=bus, store=store, connection_repo=repo)
mock_client = _make_mock_langgraph_client()
mock_client.threads.create = AsyncMock(
side_effect=[
{"thread_id": "thread-alice"},
{"thread_id": "thread-bob"},
]
)
manager._client = mock_client
await manager._handle_chat(
InboundMessage(
channel_name="slack",
chat_id="C-shared",
user_id="U-alice",
owner_user_id="alice",
connection_id=alice["id"],
text="hello",
thread_ts="1710000000.000100",
topic_id="1710000000.000100",
)
)
await manager._handle_chat(
InboundMessage(
channel_name="slack",
chat_id="C-shared",
user_id="U-bob",
owner_user_id="bob",
connection_id=bob["id"],
text="hello",
thread_ts="1710000000.000100",
topic_id="1710000000.000100",
)
)
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
assert store.list_entries() == []
first_context = mock_client.runs.wait.call_args_list[0].kwargs["context"]
second_context = mock_client.runs.wait.call_args_list[1].kwargs["context"]
assert first_context["user_id"] == "alice"
assert first_context["channel_user_id"] == "U-alice"
assert second_context["user_id"] == "bob"
assert second_context["channel_user_id"] == "U-bob"
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
try:
_run(go())
finally:
_run(close_engine())
# ---------------------------------------------------------------------------
# ChannelService tests
# ---------------------------------------------------------------------------
@@ -3175,6 +3399,226 @@ class TestChannelService:
assert service._config == {"telegram": {"enabled": False}}
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(
self,
monkeypatch,
tmp_path,
):
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
app_config = SimpleNamespace(
model_extra={},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
"slack": {"enabled": True},
"discord": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config == {}
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(
self,
monkeypatch,
tmp_path,
):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"slack",
{
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
)
app_config = SimpleNamespace(
model_extra={
"channels": {
"telegram": {"enabled": True, "bot_token": "telegram-token"},
"slack": {"enabled": True, "bot_token": "xoxb", "app_token": "xapp"},
"discord": {"enabled": True, "bot_token": "discord-bot-token"},
}
},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
"slack": {"enabled": True},
"discord": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config["telegram"]["bot_token"] == "telegram-token"
assert service._config["slack"]["app_token"] == "xapp"
assert service._config["discord"]["bot_token"] == "discord-bot-token"
def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"slack",
{
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
)
app_config = SimpleNamespace(
model_extra={},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
def test_from_app_config_runtime_disconnect_suppresses_file_channel_config(self, monkeypatch, tmp_path):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"feishu",
{
"enabled": False,
"_runtime_disabled": True,
},
)
app_config = SimpleNamespace(
model_extra={
"channels": {
"feishu": {
"enabled": True,
"app_id": "file-app-id",
"app_secret": "file-secret",
}
}
},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert "feishu" not in service._config
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
from app.channels.service import ChannelService
class FlakyReadyChannel(Channel):
starts = 0
def __init__(self, bus, config):
super().__init__(name="slack", bus=bus, config=config)
async def start(self):
type(self).starts += 1
self._running = type(self).starts >= 2
async def stop(self):
self._running = False
async def send(self, msg):
return None
monkeypatch.setattr(
"deerflow.reflection.resolve_class",
lambda import_path, base_class=None: FlakyReadyChannel,
)
async def go():
service = ChannelService(
channels_config={
"slack": {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
}
)
try:
await service.start()
assert FlakyReadyChannel.starts == 2
assert service.get_status()["channels"]["slack"]["running"] is True
finally:
await service.stop()
_run(go())
def test_connection_repo_is_forwarded_to_manager(self):
from app.channels.service import ChannelService
repo = object()
service = ChannelService(channels_config={}, connection_repo=repo)
assert service.manager._connection_repo is repo
def test_remove_channel_stops_running_channel_and_forgets_config(self):
from app.channels.service import ChannelService
async def go():
service = ChannelService(
channels_config={
"slack": {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
}
)
channel = AsyncMock()
service._channels["slack"] = channel
service._running = True
assert await service.remove_channel("slack") is True
channel.stop.assert_awaited_once()
assert "slack" not in service._channels
assert "slack" not in service._config
_run(go())
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
"""Warning is emitted when a channel has string credentials but enabled=false."""
import logging
@@ -3192,7 +3636,8 @@ class TestChannelService:
await service.stop()
_run(go())
assert any("wecom" in r.message and r.levelno == logging.WARNING for r in caplog.records)
assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records)
assert all("wecom" not in r.message for r in caplog.records)
def test_disabled_channel_with_int_creds_emits_warning(self, caplog):
"""Warning is emitted even when YAML-parsed integer credentials are present."""
@@ -3212,7 +3657,8 @@ class TestChannelService:
await service.stop()
_run(go())
assert any("telegram" in r.message and r.levelno == logging.WARNING for r in caplog.records)
assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records)
assert all("telegram" not in r.message for r in caplog.records)
def test_disabled_channel_without_creds_emits_info(self, caplog):
"""Only an info log (no warning) is emitted when a channel is disabled with no credentials."""
+4
View File
@@ -2472,6 +2472,7 @@ class TestGatewayConformance:
mem_cfg.fact_confidence_threshold = 0.7
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
mem_cfg.token_counting = "tiktoken"
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
result = client.get_memory_config()
@@ -2479,6 +2480,7 @@ class TestGatewayConformance:
parsed = MemoryConfigResponse(**result)
assert parsed.enabled is True
assert parsed.max_facts == 100
assert parsed.token_counting == "tiktoken"
def test_get_memory_status(self, client):
mem_cfg = MagicMock()
@@ -2489,6 +2491,7 @@ class TestGatewayConformance:
mem_cfg.fact_confidence_threshold = 0.7
mem_cfg.injection_enabled = True
mem_cfg.max_injection_tokens = 2000
mem_cfg.token_counting = "tiktoken"
memory_data = {
"version": "1.0",
@@ -2514,6 +2517,7 @@ class TestGatewayConformance:
parsed = MemoryStatusResponse(**result)
assert parsed.config.enabled is True
assert parsed.config.token_counting == "tiktoken"
assert parsed.data.version == "1.0"
@@ -0,0 +1,45 @@
"""Regression test for the Docker Compose default Gateway worker count.
The Gateway holds run state (RunManager and the stream bridge) in process, so
the default deployment must run a single Uvicorn worker. Running more than one
worker without a shared cross-worker stream bridge breaks run cancellation, SSE
reconnects, request de-duplication, and IM channels (nginx has no sticky
sessions, so requests scatter across workers that each keep their own run
state). This test pins the safe default so it cannot silently regress to a
multi-worker default, while still allowing operators to override it once a
shared stream bridge exists.
"""
from __future__ import annotations
import re
from pathlib import Path
import yaml
REPO_ROOT = Path(__file__).resolve().parents[2]
COMPOSE_PATH = REPO_ROOT / "docker" / "docker-compose.yaml"
def _gateway_command() -> str:
"""Return the gateway service command as a single string."""
compose = yaml.safe_load(COMPOSE_PATH.read_text(encoding="utf-8"))
command = compose["services"]["gateway"]["command"]
# ``command`` may load as a scalar string or a list depending on YAML style.
if isinstance(command, list):
command = " ".join(str(part) for part in command)
return command
def test_gateway_defaults_to_single_worker():
"""With GATEWAY_WORKERS unset, the worker count must default to 1."""
command = _gateway_command()
match = re.search(r"GATEWAY_WORKERS:-(\d+)", command)
assert match is not None, f"gateway command must set a GATEWAY_WORKERS default; got: {command}"
assert match.group(1) == "1", f"default Gateway worker count must be 1, got {match.group(1)}"
def test_gateway_worker_count_remains_overridable():
"""The worker count must stay configurable, not hard-coded to 1."""
command = _gateway_command()
assert "${GATEWAY_WORKERS:-1}" in command, f"worker count must use ${{GATEWAY_WORKERS:-1}} so operators can override it; got: {command}"
+12
View File
@@ -233,3 +233,15 @@ def test_non_auth_mutation_rejects_mismatched_double_submit_token():
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."
def test_channel_posts_require_double_submit_csrf():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/channels/slack/connect",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
+73
View File
@@ -203,6 +203,79 @@ class TestLoadAgentConfig:
assert cfg.name == "legacy-agent"
# ===========================================================================
# 3b. resolve_agent_dir — memory-only directory fallback (#3390)
# ===========================================================================
class TestResolveAgentDirMemoryOnlyFallback:
"""Regression tests for #3390.
When memory is enabled, the first conversation creates a user-isolated
agent directory containing only ``memory.json`` (no ``config.yaml``).
On the next turn ``resolve_agent_dir`` must fall through to the legacy
shared layout instead of returning the incomplete user directory.
"""
def test_user_dir_with_only_memory_falls_back_to_legacy(self, tmp_path):
"""User dir has memory.json but no config.yaml → use legacy dir."""
from deerflow.config.agents_config import resolve_agent_dir
# Legacy agent with full config
legacy_dir = tmp_path / "agents" / "my-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
# User dir created by memory write — no config.yaml
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
user_dir.mkdir(parents=True)
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
result = resolve_agent_dir("my-agent", user_id="u1")
assert result == legacy_dir
def test_user_dir_with_config_takes_priority(self, tmp_path):
"""User dir with config.yaml should still win over legacy."""
from deerflow.config.agents_config import resolve_agent_dir
# Legacy
legacy_dir = tmp_path / "agents" / "my-agent"
legacy_dir.mkdir(parents=True)
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
# User dir with full config (migrated)
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
user_dir.mkdir(parents=True)
(user_dir / "config.yaml").write_text("name: my-agent\nmodel: gpt-4\n", encoding="utf-8")
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
result = resolve_agent_dir("my-agent", user_id="u1")
assert result == user_dir
def test_load_config_falls_back_when_user_dir_is_memory_only(self, tmp_path):
"""End-to-end: load_agent_config works when user dir only has memory.json."""
config_dict = {"name": "my-agent", "description": "Legacy agent", "model": "deepseek-v3"}
_write_agent(tmp_path, "my-agent", config_dict)
# Simulate memory write creating user dir without config
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
user_dir.mkdir(parents=True)
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
from deerflow.config.agents_config import load_agent_config
cfg = load_agent_config("my-agent", user_id="u1")
assert cfg.name == "my-agent"
assert cfg.model == "deepseek-v3"
# ===========================================================================
# 4. load_agent_soul
# ===========================================================================
@@ -0,0 +1,88 @@
"""Discord connection routing tests."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.channels.discord import DiscordChannel
from app.channels.message_bus import InboundMessage, MessageBus
@pytest.fixture
async def repo(tmp_path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'discord.db'}", sqlite_dir=str(tmp_path))
try:
yield ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("discord-secret"),
)
finally:
await close_engine()
@pytest.mark.anyio
async def test_discord_inbound_attaches_owner_identity_from_user_level_connection(repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="discord",
external_account_id="987",
external_account_name="Alice",
status="connected",
)
channel = DiscordChannel(
bus=MessageBus(),
config={"bot_token": "discord-bot", "connection_repo": repo},
)
inbound = InboundMessage(
channel_name="discord",
chat_id="C123",
user_id="987",
text="hello",
)
attached = await channel._attach_connection_identity(inbound, guild_id="G123")
assert attached.connection_id == connection["id"]
assert attached.owner_user_id == "alice"
assert attached.workspace_id is None
@pytest.mark.anyio
async def test_discord_connect_command_binds_gateway_identity(repo):
state = "discord-bind-code"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="discord",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = DiscordChannel(
bus=MessageBus(),
config={"bot_token": "discord-bot", "connection_repo": repo},
)
message = MagicMock()
message.author.id = 987
message.author.display_name = "Alice"
message.guild.id = 123
message.guild.name = "Deer Guild"
message.channel.id = 456
message.channel.send = AsyncMock()
handled = await channel._bind_connection_from_connect_code(message, state)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "discord"
assert connections[0]["external_account_id"] == "987"
assert connections[0]["external_account_name"] == "Alice"
assert connections[0]["workspace_id"] == "123"
assert connections[0]["workspace_name"] == "Deer Guild"
assert connections[0]["metadata"]["channel_id"] == "456"
message.channel.send.assert_awaited_once()
+25
View File
@@ -73,6 +73,31 @@ def test_feishu_on_message_plain_text():
assert mock_make_inbound.call_args[1]["text"] == "Hello world"
def test_feishu_is_not_running_when_ws_thread_exits():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
channel._running = True
channel._thread = MagicMock()
channel._thread.is_alive.return_value = False
assert channel.is_running is False
def test_feishu_event_handler_ignores_non_content_message_events():
import lark_oapi as lark
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
event_handler = channel._build_event_handler(lark)
assert "p2.im.message.receive_v1" in event_handler._processorMap
assert "p2.im.message.message_read_v1" in event_handler._processorMap
assert "p2.im.message.reaction.created_v1" in event_handler._processorMap
assert "p2.im.message.reaction.deleted_v1" in event_handler._processorMap
assert "p2.im.message.recalled_v1" in event_handler._processorMap
def test_feishu_on_message_rich_text():
bus = MessageBus()
config = {"app_id": "test", "app_secret": "test"}
+89
View File
@@ -4,6 +4,18 @@ from __future__ import annotations
import json
import pytest
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
@pytest.fixture
def _stub_app_config():
"""Keep run-context tests independent from a developer-local config.yaml."""
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def test_format_sse_basic():
from app.gateway.services import format_sse
@@ -474,6 +486,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
assert config["context"]["user_id"] == "channel-user-7"
def test_start_run_uses_internal_owner_header_for_persistence(_stub_app_config):
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
from app.gateway.services import start_run
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime import RunManager
from deerflow.runtime.runs.store.memory import MemoryRunStore
from deerflow.runtime.user_context import get_effective_user_id
async def _scenario():
run_store = MemoryRunStore()
thread_store = MemoryThreadMetaStore(InMemoryStore())
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
run_manager = RunManager(store=run_store)
state = SimpleNamespace(
stream_bridge=SimpleNamespace(),
run_manager=run_manager,
checkpointer=InMemorySaver(),
store=InMemoryStore(),
run_event_store=SimpleNamespace(),
run_events_config=None,
thread_store=thread_store,
)
request = SimpleNamespace(
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
app=SimpleNamespace(state=state),
)
body = SimpleNamespace(
assistant_id="lead_agent",
input={"messages": [{"role": "human", "content": "hi"}]},
metadata={},
config=None,
context=None,
on_disconnect="cancel",
multitask_strategy="reject",
stream_mode=None,
stream_subgraphs=False,
interrupt_before=None,
interrupt_after=None,
)
task_context: dict[str, str] = {}
async def fake_run_agent(*args, **kwargs):
task_context["user_id"] = get_effective_user_id()
with (
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
):
record = await start_run(body, "channel-thread", request)
await record.task
owner_run = await run_store.get(record.run_id, user_id="owner-1")
default_run = await run_store.get(record.run_id, user_id="default")
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
default_thread = await thread_store.get("channel-thread", user_id="default")
return owner_run, default_run, owner_thread, default_thread, task_context
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
assert owner_run is not None
assert owner_run["user_id"] == "owner-1"
assert default_run is None
assert owner_thread is not None
assert owner_thread["user_id"] == "owner-1"
assert owner_thread["metadata"] == {"legacy": True}
assert default_thread is None
assert task_context["user_id"] == "owner-1"
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
+15
View File
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
assert reloaded.is_valid_internal_auth_token(token) is True
finally:
importlib.reload(reloaded)
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
import app.gateway.internal_auth as internal_auth
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
reloaded = importlib.reload(internal_auth)
try:
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
finally:
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
importlib.reload(reloaded)
+9
View File
@@ -21,6 +21,7 @@ from langgraph_sdk import Auth
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
from app.gateway.langgraph_auth import add_owner_filter, authenticate
# ── Helpers ───────────────────────────────────────────────────────────────
@@ -59,6 +60,14 @@ def test_no_cookie_raises_401():
assert "Not authenticated" in str(exc.value.detail)
def test_auth_disabled_skips_csrf_and_authenticates_e2e_user(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
identity = asyncio.run(authenticate(_req(method="POST")))
assert identity == AUTH_DISABLED_USER_ID
def test_invalid_jwt_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": "garbage"})))
+4 -2
View File
@@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
explicit_config = SimpleNamespace(
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"),
)
captured: dict[str, object] = {}
@@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
captured["user_id"] = user_id
return {"facts": []}
def fake_format_memory_for_injection(memory_data, *, max_tokens):
def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True):
captured["memory_data"] = memory_data
captured["max_tokens"] = max_tokens
captured["use_tiktoken"] = use_tiktoken
return "remember this"
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
@@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
"user_id": "user-1",
"memory_data": {"facts": []},
"max_tokens": 1234,
"use_tiktoken": True,
}
@@ -612,6 +612,54 @@ class TestLocalSandboxProviderMounts:
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_logs_actionable_error_for_missing_host_path(self, tmp_path, caplog):
"""Regression for #3244.
When ``sandbox.mounts[].host_path`` is absent from the gateway process's
filesystem (the typical symptom in Docker production mode: host_path is a
host machine path that is not bind-mounted into the gateway container),
the mount is still skipped but the failure must be a hard-to-miss ERROR
log with explicit, actionable guidance about Docker bind mounts, not the
old DEBUG/WARNING that buried the silent failure.
"""
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
missing_host_path = tmp_path / "does-not-exist"
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(missing_host_path), container_path="/mnt/knowledge", read_only=True),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir, use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"),
sandbox=sandbox_config,
)
with caplog.at_level("ERROR", logger="deerflow.sandbox.local.local_sandbox_provider"):
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
# Silent-skip behaviour is preserved (no breaking change for existing deployments).
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
# The failure must be observable at ERROR level and reference the offending paths.
error_records = [r for r in caplog.records if r.levelname == "ERROR"]
assert error_records, "expected an ERROR log when host_path is missing"
message = "\n".join(r.getMessage() for r in error_records)
assert str(missing_host_path) in message
assert "/mnt/knowledge" in message
# And it must include actionable Docker guidance so users don't lose hours
# to a silent empty-mount failure in production.
lowered = message.lower()
assert "docker" in lowered
assert "gateway" in lowered
assert "docker-compose" in lowered
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
"""write_file should replace container paths in file content with local paths."""
data_dir = tmp_path / "data"
@@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None:
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
# Make token counting deterministic for this test by counting characters.
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text))
memory_data = {
"user": {},
@@ -1,5 +1,6 @@
"""Tests for user-scoped path resolution in Paths."""
import logging
from pathlib import Path
import pytest
@@ -44,6 +45,7 @@ class TestMakeSafeUserId:
# Sanitized prefix plus a stable digest of the original.
assert result.startswith("user-example-com-")
assert len(result.rsplit("-", 1)[1]) == 16
assert result == "user-example-com-b4c9a289323b21a0"
assert make_safe_user_id("user@example.com") == result
def test_sanitized_id_passes_validation(self, paths: Paths):
@@ -69,6 +71,37 @@ class TestUserDir:
def test_user_dir(self, paths: Paths):
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
def test_prepare_user_dir_migrates_unique_legacy_unsafe_bucket(self, paths: Paths):
from deerflow.config.paths import make_safe_user_id
raw = "user@example.com"
safe = make_safe_user_id(raw)
legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b"
legacy_dir.mkdir(parents=True)
(legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8")
assert paths.prepare_user_dir_for_raw_id(raw) == safe
current_dir = paths.user_dir(safe)
assert current_dir.exists()
assert not legacy_dir.exists()
assert (current_dir / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n'
def test_prepare_user_dir_skips_ambiguous_legacy_unsafe_buckets(self, paths: Paths, caplog):
from deerflow.config.paths import make_safe_user_id
users_dir = paths.base_dir / "users"
(users_dir / "a-b-1111111111111111").mkdir(parents=True)
(users_dir / "a-b-2222222222222222").mkdir(parents=True)
caplog.set_level(logging.WARNING)
assert paths.prepare_user_dir_for_raw_id("a.b") == make_safe_user_id("a.b")
assert not paths.user_dir(make_safe_user_id("a.b")).exists()
assert (users_dir / "a-b-1111111111111111").exists()
assert (users_dir / "a-b-2222222222222222").exists()
assert any("Multiple legacy unsafe-id user directories matched" in r.message for r in caplog.records)
class TestUserMemoryFile:
def test_user_memory_file(self, paths: Paths):
+26 -2
View File
@@ -257,14 +257,38 @@ def test_provisioner_is_alive_true_only_when_status_running(monkeypatch):
assert backend._provisioner_is_alive("abc123") is False
def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch):
def test_provisioner_is_alive_returns_false_on_404(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
return _StubResponse(status_code=404)
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_is_alive("abc123") is False
def test_provisioner_is_alive_raises_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
raise requests.RequestException("boom")
monkeypatch.setattr(requests, "get", mock_get)
assert backend._provisioner_is_alive("abc123") is False
with pytest.raises(RuntimeError, match="Provisioner health check failed for abc123"):
backend._provisioner_is_alive("abc123")
def test_provisioner_is_alive_raises_on_server_error(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_get(url: str, timeout: int):
response = _StubResponse(status_code=503)
response.text = "unavailable"
return response
monkeypatch.setattr(requests, "get", mock_get)
with pytest.raises(RuntimeError, match="HTTP 503 unavailable"):
backend._provisioner_is_alive("abc123")
def test_discover_delegates_to_provisioner_discover(monkeypatch):
+4 -3
View File
@@ -179,15 +179,16 @@ class TestLifecycleCallbacks:
assert "run.end" in types
@pytest.mark.anyio
async def test_nested_chain_no_run_start(self, journal_setup):
"""Nested chains (parent_run_id set) should NOT produce run.start."""
async def test_nested_chain_no_run_lifecycle_events(self, journal_setup):
"""Nested chains (parent_run_id set) should NOT produce root run lifecycle events."""
j, store = journal_setup
parent_id = uuid4()
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
j.on_chain_end({}, run_id=uuid4())
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
await j.flush()
events = await store.list_events("t1", "r1")
assert not any(e["event_type"] == "run.start" for e in events)
assert not any(e["event_type"] == "run.end" for e in events)
class TestToolCallbacks:
+183
View File
@@ -5,7 +5,10 @@ import asyncio
import pytest
from langchain.agents.middleware import AgentMiddleware
from langchain.tools import ToolRuntime
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.runtime import Runtime
from langgraph.types import Command
from deerflow.sandbox.middleware import SandboxMiddleware
from deerflow.sandbox.sandbox import Sandbox
@@ -223,3 +226,183 @@ async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pyte
assert result == {"delegated": True}
assert calls == [(state, runtime)]
# ---------------------------------------------------------------------------
# wrap_tool_call / awrap_tool_call: persistent sandbox state via Command
# ---------------------------------------------------------------------------
def _make_tool_call_request(state: dict) -> ToolCallRequest:
"""Build a minimal ToolCallRequest backed by a real ToolRuntime."""
runtime = ToolRuntime(
state=state,
context={},
config={"configurable": {}},
stream_writer=lambda _: None,
tools=[],
tool_call_id="call-1",
store=None,
)
return ToolCallRequest(
tool_call={"id": "call-1", "name": "bash", "args": {}},
tool=None,
state=state,
runtime=runtime,
)
def test_wrap_tool_call_emits_command_when_lazy_init_happens() -> None:
middleware = SandboxMiddleware()
state: dict = {}
request = _make_tool_call_request(state)
def handler(req: ToolCallRequest) -> ToolMessage:
# Simulate ensure_sandbox_initialized() mutating runtime.state in-place.
req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
return ToolMessage(content="ok", tool_call_id="call-1", name="bash")
result = middleware.wrap_tool_call(request, handler)
assert isinstance(result, Command)
assert isinstance(result.update, dict)
assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"}
messages = result.update["messages"]
assert len(messages) == 1
assert messages[0].content == "ok"
assert messages[0].tool_call_id == "call-1"
def test_wrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
middleware = SandboxMiddleware()
state: dict = {"sandbox": {"sandbox_id": "existing"}}
request = _make_tool_call_request(state)
original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
def handler(req: ToolCallRequest) -> ToolMessage:
return original
result = middleware.wrap_tool_call(request, handler)
assert result is original
def test_wrap_tool_call_passthrough_when_handler_did_not_initialize_sandbox() -> None:
middleware = SandboxMiddleware()
state: dict = {}
request = _make_tool_call_request(state)
original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
def handler(req: ToolCallRequest) -> ToolMessage:
return original
result = middleware.wrap_tool_call(request, handler)
assert result is original
def test_wrap_tool_call_merges_with_existing_command_update() -> None:
middleware = SandboxMiddleware()
state: dict = {}
request = _make_tool_call_request(state)
tool_msg = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
def handler(req: ToolCallRequest) -> Command:
req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
return Command(
update={
"messages": [tool_msg],
"viewed_images": {"a.png": {"base64": "x", "mime_type": "image/png"}},
},
goto="next-node",
)
result = middleware.wrap_tool_call(request, handler)
assert isinstance(result, Command)
assert result.goto == "next-node"
assert isinstance(result.update, dict)
assert result.update["messages"] == [tool_msg]
assert result.update["viewed_images"] == {"a.png": {"base64": "x", "mime_type": "image/png"}}
assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"}
def test_wrap_tool_call_does_not_override_non_dict_update() -> None:
middleware = SandboxMiddleware()
state: dict = {}
request = _make_tool_call_request(state)
cmd = Command(update=[("messages", [ToolMessage(content="x", tool_call_id="c", name="bash")])])
def handler(req: ToolCallRequest) -> Command:
req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"}
return cmd
result = middleware.wrap_tool_call(request, handler)
# Non-dict update is left untouched to avoid silent data loss.
assert result is cmd
@pytest.mark.anyio
async def test_awrap_tool_call_emits_command_when_lazy_init_happens() -> None:
middleware = SandboxMiddleware()
state: dict = {}
request = _make_tool_call_request(state)
async def handler(req: ToolCallRequest) -> ToolMessage:
req.runtime.state["sandbox"] = {"sandbox_id": "async-new"}
return ToolMessage(content="ok", tool_call_id="call-1", name="bash")
result = await middleware.awrap_tool_call(request, handler)
assert isinstance(result, Command)
assert isinstance(result.update, dict)
assert result.update["sandbox"] == {"sandbox_id": "async-new"}
messages = result.update["messages"]
assert len(messages) == 1
assert messages[0].content == "ok"
@pytest.mark.anyio
async def test_awrap_tool_call_passthrough_when_sandbox_already_in_state() -> None:
middleware = SandboxMiddleware()
state: dict = {"sandbox": {"sandbox_id": "existing"}}
request = _make_tool_call_request(state)
original = ToolMessage(content="ok", tool_call_id="call-1", name="bash")
async def handler(req: ToolCallRequest) -> ToolMessage:
return original
result = await middleware.awrap_tool_call(request, handler)
assert result is original
def test_wrap_tool_call_preserves_existing_command_fields_when_merging() -> None:
"""Regression: when merging sandbox_update into an existing Command,
all other Command fields (e.g. graph, goto, resume) must be preserved.
"""
middleware = SandboxMiddleware()
state: dict = {}
request = _make_tool_call_request(state)
def handler(req: ToolCallRequest) -> Command:
req.runtime.state["sandbox"] = {"sandbox_id": "sbx-merge"}
return Command(
update={"existing_key": "existing_value"},
graph="parent",
goto="next_node",
resume="resume-token",
)
result = middleware.wrap_tool_call(request, handler)
assert isinstance(result, Command)
assert result.update == {
"existing_key": "existing_value",
"sandbox": {"sandbox_id": "sbx-merge"},
}
# Critical: other Command fields must NOT be dropped by the merge.
assert result.graph == "parent"
assert result.goto == "next_node"
assert result.resume == "resume-token"
+75
View File
@@ -7,7 +7,9 @@ Run from repo root:
from __future__ import annotations
import yaml
from wizard import ui as wizard_ui
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
from wizard.steps import channels as channels_step
from wizard.steps import llm as llm_step
from wizard.steps import search as search_step
from wizard.writer import (
@@ -327,6 +329,44 @@ class TestBuildMinimalConfig:
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
def test_can_enable_selected_channel_connections(self):
content = build_minimal_config(
provider_use="langchain_openai:ChatOpenAI",
model_name="gpt-4o",
display_name="OpenAI",
api_key_field="api_key",
env_var="OPENAI_API_KEY",
channel_connection_providers=["feishu", "slack"],
)
data = yaml.safe_load(content)
channel_connections = data["channel_connections"]
assert channel_connections["enabled"] is True
assert channel_connections["feishu"]["enabled"] is True
assert channel_connections["slack"]["enabled"] is True
assert channel_connections["telegram"]["enabled"] is False
assert channel_connections["discord"]["enabled"] is False
assert channel_connections["dingtalk"]["enabled"] is False
assert channel_connections["wechat"]["enabled"] is False
assert channel_connections["wecom"]["enabled"] is False
def test_channel_connections_disabled_when_no_channels_selected(self):
content = build_minimal_config(
provider_use="langchain_openai:ChatOpenAI",
model_name="gpt-4o",
display_name="OpenAI",
api_key_field="api_key",
env_var="OPENAI_API_KEY",
channel_connection_providers=[],
)
data = yaml.safe_load(content)
channel_connections = data["channel_connections"]
assert channel_connections["enabled"] is False
assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled")
class TestLLMStep:
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
@@ -384,6 +424,41 @@ class TestLLMStep:
assert result.base_url == "https://gateway.example/v1"
class TestChannelsStep:
def test_returns_selected_channel_keys(self, monkeypatch):
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6])
result = channels_step.run_channels_step()
assert result.enabled_providers == ["telegram", "feishu", "wecom"]
def test_empty_selection_disables_channel_connections(self, monkeypatch):
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [])
result = channels_step.run_channels_step()
assert result.enabled_providers == []
class TestWizardUi:
def test_multi_choice_blank_requires_input_without_default(self, monkeypatch):
answers = iter(["", "2"])
monkeypatch.setattr("builtins.input", lambda _prompt: next(answers))
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1]
def test_multi_choice_blank_accepts_empty_default(self, monkeypatch):
monkeypatch.setattr("builtins.input", lambda _prompt: "")
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == []
# ---------------------------------------------------------------------------
# writer.py — env file helpers
# ---------------------------------------------------------------------------
@@ -0,0 +1,154 @@
"""Slack connection tests for user-owned channel bindings."""
from __future__ import annotations
import sys
from datetime import UTC, datetime, timedelta
from types import ModuleType
from unittest.mock import AsyncMock, MagicMock
from app.channels.message_bus import MessageBus, OutboundMessage
async def _make_repo(tmp_path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'slack.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("slack-secret"),
)
def test_slack_connect_command_binds_socket_mode_identity(tmp_path):
import anyio
from app.channels.slack import SlackChannel
async def go():
repo = await _make_repo(tmp_path)
state = "slack-bind-code"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="slack",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = SlackChannel(
bus=MessageBus(),
config={"bot_token": "xoxb-operator", "app_token": "xapp-operator", "connection_repo": repo},
)
channel._web_client = MagicMock()
handled = await channel._bind_connection_from_connect_code(
event={
"user": "U123",
"channel": "C123",
"ts": "1710000000.000100",
},
team_id="T123",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "slack"
assert connections[0]["external_account_id"] == "U123"
assert connections[0]["workspace_id"] == "T123"
assert connections[0]["metadata"]["channel_id"] == "C123"
channel._web_client.chat_postMessage.assert_called_once()
await repo.close()
anyio.run(go)
def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
import anyio
from app.channels.slack import SlackChannel
async def go():
repo = AsyncMock()
repo.get_credentials.return_value = {"access_token": "xoxb-connection-token"}
web_client = MagicMock()
web_client_factory = MagicMock(return_value=web_client)
channel = SlackChannel(
bus=MessageBus(),
config={
"connection_repo": repo,
"web_client_factory": web_client_factory,
},
)
msg = OutboundMessage(
channel_name="slack",
chat_id="C123",
thread_id="thread-1",
text="hello",
connection_id="connection-1",
)
await channel.send(msg)
repo.get_credentials.assert_awaited_once_with("connection-1")
web_client_factory.assert_called_once_with(token="xoxb-connection-token")
web_client.chat_postMessage.assert_called_once()
anyio.run(go)
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
import anyio
from app.channels.slack import SlackChannel
class FakeWebClient:
def __init__(self, token: str) -> None:
self.token = token
self.messages: list[dict] = []
def auth_test(self):
return {"user_id": "B-http"}
def chat_postMessage(self, **kwargs):
self.messages.append(kwargs)
slack_sdk = ModuleType("slack_sdk")
slack_sdk.WebClient = FakeWebClient
socket_mode = ModuleType("slack_sdk.socket_mode")
socket_mode.SocketModeClient = object
response = ModuleType("slack_sdk.socket_mode.response")
response.SocketModeResponse = object
monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk)
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode)
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response)
async def go():
channel = SlackChannel(
bus=MessageBus(),
config={
"bot_token": "xoxb-operator",
"event_delivery": "http",
"connection_repo": MagicMock(),
},
)
await channel.start()
assert channel._running is True
assert channel._web_client is not None
assert channel._web_client.token == "xoxb-operator"
assert channel._bot_user_id == "B-http"
channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
assert channel._web_client.messages == [
{
"channel": "C123",
"text": "Slack connected to DeerFlow.",
"thread_ts": "1710000000.000100",
}
]
await channel.stop()
anyio.run(go)
@@ -0,0 +1,173 @@
"""Cross-user isolation for the stateless ``POST /api/runs/stream`` and ``/wait`` endpoints.
These endpoints receive ``thread_id`` in the request body, so the
``@require_permission(owner_check=True)`` decorator which reads the
``thread_id`` *path* parameter cannot protect them. The owner check
lives inside ``services.start_run()`` instead; this suite pins it at the
HTTP layer so the gap cannot silently reopen.
Strategy
--------
``app.state.run_manager.create_or_reject`` raises ``ConflictError``, so a
request that *passes* the owner check deterministically short-circuits
with 409 before any agent code runs. The two outcomes:
- 404 + ``create_or_reject`` never awaited -> blocked by the owner check
- 409 + ``create_or_reject`` awaited -> passed the owner check
The thread store is a real ``MemoryThreadMetaStore`` (not a mock) so the
``check_access`` semantics under test missing row allows, ``user_id``
NULL allows, foreign owner denies are exercised through real code.
"""
from __future__ import annotations
import asyncio
from contextlib import contextmanager
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from langgraph.store.memory import InMemoryStore
from app.gateway.auth.models import User
from app.gateway.routers import runs
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime import ConflictError
USER_A = User(email="owner-a@example.com", password_hash="x", system_role="user", id=uuid4())
USER_B = User(email="intruder-b@example.com", password_hash="x", system_role="user", id=uuid4())
INTERNAL_USER = SimpleNamespace(id="default", system_role="internal")
THREAD_A = "thread-owned-by-a"
THREAD_SHARED = "thread-shared-null-owner"
@pytest.fixture(autouse=True)
def _stub_app_config():
"""Inject a minimal AppConfig so the allowed path (which builds a
RunContext via ``get_config()``) never reads config.yaml from disk."""
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def _make_thread_store() -> MemoryThreadMetaStore:
store = MemoryThreadMetaStore(InMemoryStore())
async def _seed():
await store.create(THREAD_A, user_id=str(USER_A.id))
await store.create(THREAD_SHARED, user_id=None)
asyncio.run(_seed())
return store
@contextmanager
def _client(user):
"""Yield a ``TestClient`` authenticated as ``user`` plus the stubbed
``create_or_reject`` mock, closing the client (and its anyio portal /
background threads) on exit.
``create_or_reject`` raises ``ConflictError`` so a request that passes the
owner check short-circuits to 409 before any agent code runs.
"""
app = make_authed_test_app(user_factory=lambda: user)
app.include_router(runs.router)
app.state.thread_store = _make_thread_store()
app.state.stream_bridge = MagicMock()
app.state.checkpointer = MagicMock()
app.state.store = MagicMock()
app.state.run_events_config = None
app.state.run_event_store = MagicMock()
run_manager = MagicMock()
run_manager.create_or_reject = AsyncMock(side_effect=ConflictError("sentinel: owner check passed"))
app.state.run_manager = run_manager
with TestClient(app) as client:
yield client, run_manager.create_or_reject
def _body(thread_id: str | None = None) -> dict:
if thread_id is None:
return {}
return {"config": {"configurable": {"thread_id": thread_id}}}
# ---------------------------------------------------------------------------
# Denied: another user's thread
# ---------------------------------------------------------------------------
def test_stream_cross_user_returns_404():
"""User B cannot start a run on user A's thread via /api/runs/stream."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A))
assert response.status_code == 404
assert response.json()["detail"] == f"Thread {THREAD_A} not found"
create_or_reject.assert_not_awaited()
def test_wait_cross_user_returns_404_without_channel_values():
"""User B cannot read user A's checkpoint state via /api/runs/wait."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/wait", json=_body(THREAD_A))
assert response.status_code == 404
assert response.json() == {"detail": f"Thread {THREAD_A} not found"}
create_or_reject.assert_not_awaited()
# ---------------------------------------------------------------------------
# Allowed: owner, fresh/untracked/shared threads, internal role
# ---------------------------------------------------------------------------
def test_stream_owner_passes_owner_check():
"""User A reaches run creation on their own thread (409 sentinel)."""
with _client(USER_A) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_wait_owner_passes_owner_check():
with _client(USER_A) as (client, create_or_reject):
response = client.post("/api/runs/wait", json=_body(THREAD_A))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_without_thread_id_passes_owner_check():
"""Stateless run with no thread_id auto-creates a thread — never blocked."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body())
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_untracked_thread_passes_owner_check():
"""A thread_id with no thread_meta row (untracked legacy) stays accessible."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body("never-created-thread"))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_shared_thread_passes_owner_check():
"""A thread_meta row with user_id NULL (shared / pre-auth data) stays accessible."""
with _client(USER_B) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
assert response.status_code == 409
create_or_reject.assert_awaited()
def test_stream_internal_role_bypasses_owner_check():
"""IM channels run with the internal system role on behalf of platform
users whose threads they do not own the owner check must not break them."""
with _client(INTERNAL_USER) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A))
assert response.status_code == 409
create_or_reject.assert_awaited()
@@ -0,0 +1,100 @@
"""Tests for Telegram deep-link channel connections."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.channels.message_bus import MessageBus
from app.channels.telegram import TelegramChannel
@pytest.fixture
async def repo(tmp_path: Path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'telegram.db'}", sqlite_dir=str(tmp_path))
try:
yield ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("telegram-secret"),
)
finally:
await close_engine()
def _telegram_update(*, text: str = "/start", user_id: int = 42, chat_id: int = 100, chat_type: str = "private"):
update = MagicMock()
update.effective_user.id = user_id
update.effective_user.username = "alice"
update.effective_user.full_name = "Alice Example"
update.effective_chat.id = chat_id
update.effective_chat.type = chat_type
update.message.text = text
update.message.message_id = 55
update.message.reply_to_message = None
update.message.reply_text = AsyncMock()
return update
@pytest.mark.anyio
async def test_start_with_deep_link_state_binds_telegram_chat(repo):
state = "telegram-bind-state"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="telegram",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = TelegramChannel(
bus=MessageBus(),
config={"bot_token": "test-token", "connection_repo": repo},
)
update = _telegram_update(text=f"/start {state}")
context = MagicMock()
context.args = [state]
await channel._cmd_start(update, context)
connections = await repo.list_connections("deerflow-user-1")
assert len(connections) == 1
assert connections[0]["provider"] == "telegram"
assert connections[0]["external_account_id"] == "42"
assert connections[0]["external_account_name"] == "Alice Example"
assert connections[0]["workspace_id"] == "100"
assert connections[0]["metadata"]["chat_type"] == "private"
update.message.reply_text.assert_awaited_once()
assert "connected" in update.message.reply_text.await_args.args[0].lower()
@pytest.mark.anyio
async def test_bound_telegram_message_publishes_connection_identity(repo):
connection = await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="telegram",
external_account_id="42",
external_account_name="Alice Example",
workspace_id="100",
metadata={"chat_type": "private"},
)
bus = MessageBus()
channel = TelegramChannel(
bus=bus,
config={"bot_token": "test-token", "connection_repo": repo},
)
channel._main_loop = __import__("asyncio").get_event_loop()
channel._send_running_reply = AsyncMock()
await channel._on_text(_telegram_update(text="hello"), None)
inbound = await bus.get_inbound()
assert inbound.connection_id == connection["id"]
assert inbound.owner_user_id == "deerflow-user-1"
assert inbound.workspace_id == "100"
assert inbound.user_id == "42"
assert inbound.chat_id == "100"
assert inbound.text == "hello"
+13
View File
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
async def test_update_metadata_nonexistent_is_noop(self, repo):
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
@pytest.mark.anyio
async def test_update_owner_with_bypass_moves_row(self, repo):
await repo.create("t1", user_id="default", metadata={"source": "channel"})
await repo.update_owner("t1", "owner-1", user_id=None)
owner_row = await repo.get("t1", user_id="owner-1")
default_row = await repo.get("t1", user_id="default")
assert owner_row is not None
assert owner_row["user_id"] == "owner-1"
assert owner_row["metadata"] == {"source": "channel"}
assert default_row is None
# --- search with metadata filter (SQL push-down) ---
@pytest.mark.anyio
+32
View File
@@ -1,4 +1,5 @@
import re
from types import SimpleNamespace
from unittest.mock import patch
import pytest
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
assert body["created_at"] == body["updated_at"]
def test_internal_owner_header_assigns_thread_to_owner() -> None:
import asyncio
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
store = InMemoryStore()
checkpointer = InMemorySaver()
thread_store = MemoryThreadMetaStore(store)
request = SimpleNamespace(
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
)
async def _scenario():
response = await threads.create_thread(
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
request,
)
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
internal_row = await thread_store.get("channel-thread", user_id="default")
return response, owner_row, internal_row
response, owner_row, internal_row = asyncio.run(_scenario())
assert response.thread_id == "channel-thread"
assert owner_row is not None
assert owner_row["user_id"] == "owner-1"
assert internal_row is None
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
"""A thread record written by older versions stores ``time.time()``
floats. ``get_thread`` must transparently surface them as ISO so the
@@ -5,18 +5,22 @@ Verifies:
- ``_count_tokens`` falls back to character estimation when tiktoken is
unavailable or the encoding fails to load.
- ``warm_tiktoken_cache`` populates the cache on success.
- An in-flight tiktoken load prevents duplicate blocking downloads.
"""
from __future__ import annotations
import threading
from unittest import mock
from deerflow.agents.memory.prompt import (
_count_tokens,
_get_tiktoken_encoding,
_tiktoken_encoding_cache,
format_memory_for_injection,
warm_tiktoken_cache,
)
from deerflow.config.memory_config import MemoryConfig
# ---------------------------------------------------------------------------
# _get_tiktoken_encoding
@@ -62,14 +66,103 @@ class TestGetTiktokenEncoding:
assert enc is fake_enc
tiktoken.get_encoding.assert_not_called()
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
_tiktoken_encoding_cache.pop("bogus_encoding", None)
import tiktoken
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
get_encoding = mock.Mock(side_effect=OSError("download failed"))
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
result = _get_tiktoken_encoding("bogus_encoding")
assert result is None
assert "bogus_encoding" not in _tiktoken_encoding_cache
# The failure is remembered as a (None, timestamp) tuple.
assert "bogus_encoding" in _tiktoken_encoding_cache
cached = _tiktoken_encoding_cache["bogus_encoding"]
assert isinstance(cached, tuple)
assert cached[0] is None
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
# the network download in restricted environments — see #3429).
result2 = _get_tiktoken_encoding("bogus_encoding")
assert result2 is None
assert get_encoding.call_count == 1
# Cleanup module-level cache to avoid cross-test leakage.
_tiktoken_encoding_cache.pop("bogus_encoding", None)
def test_failure_self_heals_after_cooldown(self, monkeypatch):
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
_tiktoken_encoding_cache.pop("flaky_encoding", None)
import tiktoken
fake_enc = mock.Mock()
# First call fails, second call (after cooldown) succeeds.
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
# Initial failure is cached.
assert _get_tiktoken_encoding("flaky_encoding") is None
assert get_encoding.call_count == 1
# Within the cooldown window: no retry, immediate fallback.
assert _get_tiktoken_encoding("flaky_encoding") is None
assert get_encoding.call_count == 1
# Simulate the cooldown having elapsed by ageing the cached timestamp.
from deerflow.agents.memory import prompt as prompt_module
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
_tiktoken_encoding_cache["flaky_encoding"] = (
None,
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
)
# Now the load is retried and recovers to accurate counting.
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
assert get_encoding.call_count == 2
_tiktoken_encoding_cache.pop("flaky_encoding", None)
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
"""Concurrent callers must not start duplicate blocking BPE downloads."""
_tiktoken_encoding_cache.pop("slow_encoding", None)
import tiktoken
started = threading.Event()
release = threading.Event()
fake_enc = mock.Mock()
def slow_get_encoding(_name):
started.set()
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
return fake_enc
get_encoding = mock.Mock(side_effect=slow_get_encoding)
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
result: dict[str, object | None] = {}
def load_encoding():
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
thread = threading.Thread(target=load_encoding)
thread.start()
try:
assert started.wait(timeout=1), "slow get_encoding did not start"
# While the first call is still blocked, a second call should see
# the in-flight sentinel and fall back immediately instead of
# starting another potentially long network download.
assert _get_tiktoken_encoding("slow_encoding") is None
assert get_encoding.call_count == 1
finally:
release.set()
thread.join(timeout=2)
_tiktoken_encoding_cache.pop("slow_encoding", None)
assert result["encoding"] is fake_enc
assert get_encoding.call_count == 1
# ---------------------------------------------------------------------------
@@ -115,6 +208,45 @@ class TestCountTokens:
result = _count_tokens(text, encoding_name="test_enc")
assert result == len(text) // 4
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
# Spy on both the encoding loader and tiktoken.get_encoding directly.
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
text = "Hello, world! This is a network-free count."
result = _count_tokens(text, use_tiktoken=False)
assert result == len(text) // 4
get_encoding_spy.assert_not_called()
loader_spy.assert_not_called()
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
CJK characters are ~2 chars/token, so the char-based estimate must not
under-fill the budget the way ``len(text) // 4`` would.
"""
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
# "User prefers concise answers" rendered in CJK (Chinese) characters.
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
result = _count_tokens(text)
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
assert result == len(text) // 2
assert result > len(text) // 4
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
"""Mixed-language text should apply the CJK density only to CJK chars."""
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
result = _count_tokens(text)
assert result == (len(text) - cjk) // 4 + cjk // 2
# ---------------------------------------------------------------------------
# warm_tiktoken_cache
@@ -146,3 +278,69 @@ class TestWarmTiktokenCache:
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
assert warm_tiktoken_cache() is False
# ---------------------------------------------------------------------------
# format_memory_for_injection token_counting strategy
# ---------------------------------------------------------------------------
class TestFormatMemoryForInjectionTokenCounting:
"""Verify the use_tiktoken flag is honoured end-to-end."""
@staticmethod
def _sample_memory() -> dict:
return {
"facts": [
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
],
}
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
assert "User prefers concise answers." in result
get_encoding_spy.assert_not_called()
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
fake_enc = mock.Mock()
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
monkeypatch.setattr(
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
mock.Mock(return_value=fake_enc),
)
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
assert "User prefers concise answers." in result
assert fake_enc.encode.called
def test_empty_memory_returns_empty(self):
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
# ---------------------------------------------------------------------------
# MemoryConfig.token_counting
# ---------------------------------------------------------------------------
class TestMemoryConfigTokenCounting:
"""Verify the new config field defaults and validation."""
def test_default_is_tiktoken(self):
"""Default must remain tiktoken so existing deployments are unaffected."""
assert MemoryConfig().token_counting == "tiktoken"
def test_accepts_char(self):
assert MemoryConfig(token_counting="char").token_counting == "char"
def test_rejects_invalid_value(self):
import pytest
from pydantic import ValidationError
with pytest.raises(ValidationError):
MemoryConfig(token_counting="invalid")
+2
View File
@@ -820,6 +820,7 @@ dependencies = [
{ name = "agent-sandbox" },
{ name = "aiosqlite" },
{ name = "alembic" },
{ name = "cryptography" },
{ name = "ddgs" },
{ name = "dotenv" },
{ name = "duckdb" },
@@ -871,6 +872,7 @@ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.19" },
{ name = "alembic", specifier = ">=1.13" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
{ name = "cryptography", specifier = ">=43.0.0" },
{ name = "ddgs", specifier = ">=9.10.0" },
{ name = "dotenv", specifier = ">=0.9.9" },
{ name = "duckdb", specifier = ">=1.4.4" },
+54 -2
View File
@@ -15,7 +15,7 @@
# ============================================================================
# Bump this number when the config schema changes.
# Run `make config-upgrade` to merge new fields into your local config.yaml.
config_version: 11
config_version: 12
# ============================================================================
# Logging
@@ -768,8 +768,12 @@ sandbox:
allow_host_bash: false
# Optional: Mount additional host directories into the sandbox.
# Each mount maps a host path to a virtual container path accessible by the agent.
# Note: with LocalSandboxProvider under `make up` (docker-compose), host_path is
# checked from inside the deer-flow-gateway container — you must also bind-mount
# the same directory into services.gateway.volumes in docker/docker-compose.yaml
# for this mount to take effect (see issue #3244).
# mounts:
# - host_path: /home/user/my-project # Absolute path on the host machine
# - host_path: /home/user/my-project # Absolute path; see note above for Docker mode
# container_path: /mnt/my-project # Virtual path inside the sandbox
# read_only: true # Whether the mount is read-only (default: false)
@@ -1020,6 +1024,15 @@ memory:
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
injection_enabled: true # Whether to inject memory into system prompt
max_injection_tokens: 2000 # Maximum tokens for memory injection
# Token counting strategy for memory-injection budgeting:
# tiktoken (default) - accurate, but the encoding's BPE data may be
# downloaded from a public network endpoint on first use. In
# network-restricted environments this download can block for a long
# time (see issues #3402 / #3429). Pre-cache the encoding or set this
# to "char" to avoid it.
# char - network-free CJK-aware character-based estimate; never touches
# tiktoken. Slightly less precise budgeting, zero network I/O.
token_counting: tiktoken
# ============================================================================
# Custom Agent Management API
@@ -1127,6 +1140,45 @@ run_events:
max_trace_content: 10240
track_token_usage: true
# ============================================================================
# User-Owned IM Channel Connections
# ============================================================================
# Lets logged-in users connect their own IM accounts from the DeerFlow frontend
# while reusing the existing `channels` runtime configuration below.
#
# Security notes:
# - No public IP, OAuth callback URL, or provider webhook is required.
# - Provider bot/app credentials stay under `channels.*`.
# - `channel_connections` stores per-user bindings and one-time connect codes.
# - Telegram uses a deep link when `bot_username` is configured.
# - Slack, Discord, Feishu, DingTalk, WeChat, and WeCom use `/connect <code>`
# through the already-running bot/app.
#
# channel_connections:
# enabled: false
#
# telegram:
# enabled: false
# bot_username: $TELEGRAM_BOT_USERNAME
#
# slack:
# enabled: false
#
# discord:
# enabled: false
#
# feishu:
# enabled: false
#
# dingtalk:
# enabled: false
#
# wechat:
# enabled: false
#
# wecom:
# enabled: false
# ============================================================================
# IM Channels Configuration
# ============================================================================
+7 -1
View File
@@ -72,7 +72,13 @@ services:
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
UV_EXTRAS: ${UV_EXTRAS:-}
container_name: deer-flow-gateway
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-4}"
# Gateway hosts the agent runtime with in-process RunManager + StreamBridge
# singletons -- run state lives in this worker's memory. Default to a single
# worker: with >1 worker and no nginx sticky sessions, run cancel, SSE
# reconnect, request dedup, and per-worker IM channel services all break
# across workers until a shared (e.g. redis) stream bridge lands, which is
# not yet implemented. Override GATEWAY_WORKERS only once that is in place.
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-1}"
volumes:
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
+7 -3
View File
@@ -7,8 +7,9 @@ import { defineConfig, devices } from "@playwright/test";
* so the mock-based suite is untouched.
*
* Two webServers are started: the replay gateway (:8011) and the frontend
* (:3000, pointed at the gateway). Auth uses a throwaway test account the spec
* registers at runtime no secrets.
* (:3000, pointed at the gateway). Auth-disabled mode is enabled on both
* servers so the no-cookie e2e contract is covered; specs that need session
* cookies still register a throwaway test account at runtime.
*/
export default defineConfig({
testDir: "./tests/e2e-real-backend",
@@ -38,7 +39,10 @@ export default defineConfig({
// Mount the test-only run/message seeder used by multi-run-order.spec.ts
// (#3352). The endpoint exists only on this replay gateway, never in the
// production app.
env: { DEERFLOW_ENABLE_TEST_SEED: "1" },
env: {
DEERFLOW_ENABLE_TEST_SEED: "1",
DEER_FLOW_AUTH_DISABLED: "1",
},
},
{
command: "pnpm build && pnpm start",
+6 -7
View File
@@ -1,7 +1,7 @@
import Link from "next/link";
import { redirect } from "next/navigation";
import { type ReactNode } from "react";
import { GatewayOfflineFallback } from "@/components/workspace/gateway-offline-fallback";
import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types";
@@ -25,18 +25,17 @@ export default async function AuthLayout({
case "unauthenticated":
return <AuthProvider initialUser={null}>{children}</AuthProvider>;
case "gateway_unavailable":
// Auth pages have no banner of their own, so render one here. The
// fallback's AuthProvider replaces the bare-HTML branch that
// previously locked users out without any logout/retry capability.
return (
<GatewayOfflineFallback renderBanner>
<div className="flex h-screen flex-col items-center justify-center gap-4">
<p className="text-muted-foreground">
Service temporarily unavailable.
</p>
<Link
href="/login"
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
>
Retry
</Link>
</div>
</GatewayOfflineFallback>
);
case "config_error":
throw new Error(result.message);
@@ -72,20 +72,21 @@ export default function AgentChatPage() {
loadMoreHistory,
} = useThreadStream({
threadId: isNewThread ? undefined : threadId,
displayThreadId: threadId,
context: { ...settings.context, agent_name: agent_name },
isMock,
onSend: () => {
setIsWelcomeMode(false);
},
onStart: (createdThreadId) => {
setThreadId(createdThreadId);
setIsNewThread(false);
// ! Important: Never use next.js router for navigation in this case, otherwise it will cause the thread to re-mount and lose all states. Use native history API instead.
history.replaceState(
null,
"",
`/workspace/agents/${agent_name}/chats/${createdThreadId}`,
);
setThreadId(createdThreadId);
setIsNewThread(false);
},
onFinish: (state) => {
if (document.hidden || !document.hasFocus()) {
@@ -75,6 +75,7 @@ export default function ChatPage() {
loadMoreHistory,
} = useThreadStream({
threadId: isNewThread ? undefined : threadId,
displayThreadId: threadId,
context: settings.context,
isMock,
// onSend only animates the UI; do NOT flip `isNewThread` here — the
@@ -84,10 +85,10 @@ export default function ChatPage() {
setIsWelcomeMode(false);
},
onStart: (createdThreadId) => {
setThreadId(createdThreadId);
setIsNewThread(false);
// ! Important: Never use next.js router for navigation in this case, otherwise it will cause the thread to re-mount and lose all states. Use native history API instead.
history.replaceState(null, "", `/workspace/chats/${createdThreadId}`);
setThreadId(createdThreadId);
setIsNewThread(false);
},
onFinish: (state) => {
if (document.hidden || !document.hasFocus()) {
+84 -9
View File
@@ -1,34 +1,77 @@
"use client";
import Link from "next/link";
import { useEffect, useMemo, useState } from "react";
import { useEffect, useMemo, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { ScrollArea } from "@/components/ui/scroll-area";
import {
ThreadChannelBadge,
ThreadChannelIcon,
} from "@/components/workspace/thread-channel-source";
import {
WorkspaceBody,
WorkspaceContainer,
WorkspaceHeader,
} from "@/components/workspace/workspace-container";
import { useI18n } from "@/core/i18n/hooks";
import { useThreads } from "@/core/threads/hooks";
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
import { useInfiniteThreads } from "@/core/threads/hooks";
import {
channelSourceOfThread,
pathOfThread,
titleOfThread,
} from "@/core/threads/utils";
import { formatTimeAgo } from "@/core/utils/datetime";
export default function ChatsPage() {
const { t } = useI18n();
const { data: threads } = useThreads();
const {
data: infiniteThreads,
fetchNextPage,
hasNextPage,
isFetchingNextPage,
} = useInfiniteThreads();
const threads = useMemo(
() => infiniteThreads?.pages.flat() ?? [],
[infiniteThreads],
);
const [search, setSearch] = useState("");
const isSearching = search.trim().length > 0;
useEffect(() => {
document.title = `${t.pages.chats} - ${t.pages.appName}`;
}, [t.pages.chats, t.pages.appName]);
const filteredThreads = useMemo(() => {
return threads?.filter((thread) => {
return threads.filter((thread) => {
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
});
}, [threads, search]);
// Sentinel-based auto load-more for the unfiltered list (issue #3482).
// In search mode we deliberately do NOT auto-paginate, otherwise an empty
// filtered view would keep the sentinel in the viewport and drain the
// entire backend list one page at a time. Searching falls back to an
// explicit button so users can still reach older conversations on demand.
const sentinelRef = useRef<HTMLDivElement | null>(null);
useEffect(() => {
const element = sentinelRef.current;
if (!element || !hasNextPage || isSearching) {
return;
}
const observer = new IntersectionObserver(
([entry]) => {
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
void fetchNextPage();
}
},
{ rootMargin: "200px 0px 200px 0px" },
);
observer.observe(element);
return () => observer.disconnect();
}, [fetchNextPage, hasNextPage, isFetchingNextPage, isSearching]);
return (
<WorkspaceContainer>
<WorkspaceHeader></WorkspaceHeader>
@@ -47,11 +90,20 @@ export default function ChatsPage() {
<main className="min-h-0 flex-1">
<ScrollArea className="size-full py-4">
<div className="mx-auto flex size-full max-w-(--container-width-md) flex-col">
{filteredThreads?.map((thread) => (
{filteredThreads.map((thread) => {
const channelSource = channelSourceOfThread(thread);
return (
<Link key={thread.thread_id} href={pathOfThread(thread)}>
<div className="flex flex-col gap-2 border-b p-4">
<div>
<div>{titleOfThread(thread)}</div>
<div className="flex min-w-0 items-center gap-2">
<ThreadChannelIcon source={channelSource} />
<div className="min-w-0 flex-1 truncate">
{titleOfThread(thread)}
</div>
<ThreadChannelBadge
source={channelSource}
className="hidden sm:inline-flex"
/>
</div>
{thread.updated_at && (
<div className="text-muted-foreground text-sm">
@@ -60,7 +112,30 @@ export default function ChatsPage() {
)}
</div>
</Link>
))}
);
})}
{hasNextPage && !isSearching && (
<div
ref={sentinelRef}
aria-hidden="true"
className="h-px w-full"
data-testid="chats-page-sentinel"
/>
)}
{hasNextPage && isSearching && (
<div className="flex justify-center p-4">
<Button
variant="outline"
onClick={() => void fetchNextPage()}
disabled={isFetchingNextPage}
data-testid="chats-page-load-more"
>
{isFetchingNextPage
? t.chats.loadingMore
: t.chats.loadMoreToSearch}
</Button>
</div>
)}
</div>
</ScrollArea>
</main>
+7 -25
View File
@@ -1,6 +1,6 @@
import Link from "next/link";
import { redirect } from "next/navigation";
import { GatewayOfflineFallback } from "@/components/workspace/gateway-offline-fallback";
import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types";
@@ -28,31 +28,13 @@ export default async function WorkspaceLayout({
case "unauthenticated":
redirect("/login");
case "gateway_unavailable":
// GatewayOfflineFallback supplies the AuthProvider; WorkspaceContent
// already mounts the banner inside its sidebar layout, so renderBanner
// stays false here to avoid double-mounting.
return (
<div className="flex h-screen flex-col items-center justify-center gap-4">
<p className="text-muted-foreground">
Service temporarily unavailable.
</p>
<p className="text-muted-foreground text-xs">
The backend may be restarting. Please wait a moment and try again.
</p>
<div className="flex gap-3">
<Link
href="/workspace"
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
>
Retry
</Link>
<form action="/api/v1/auth/logout" method="post">
<button
type="submit"
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
>
Logout &amp; Reset
</button>
</form>
</div>
</div>
<GatewayOfflineFallback>
<WorkspaceContent gatewayUnavailable>{children}</WorkspaceContent>
</GatewayOfflineFallback>
);
case "config_error":
throw new Error(result.message);
@@ -4,6 +4,7 @@ import { Toaster } from "sonner";
import { QueryClientProvider } from "@/components/query-client-provider";
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
import { CommandPalette } from "@/components/workspace/command-palette";
import { GatewayOfflineBanner } from "@/components/workspace/gateway-offline-banner";
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
function parseSidebarOpenCookie(
@@ -16,7 +17,11 @@ function parseSidebarOpenCookie(
export async function WorkspaceContent({
children,
}: Readonly<{ children: React.ReactNode }>) {
gatewayUnavailable = false,
}: Readonly<{
children: React.ReactNode;
gatewayUnavailable?: boolean;
}>) {
const cookieStore = await cookies();
const initialSidebarOpen = parseSidebarOpenCookie(
cookieStore.get("sidebar_state")?.value,
@@ -26,7 +31,10 @@ export async function WorkspaceContent({
<QueryClientProvider>
<SidebarProvider className="h-screen" defaultOpen={initialSidebarOpen}>
<WorkspaceSidebar />
<SidebarInset className="min-w-0">{children}</SidebarInset>
<SidebarInset className="min-w-0">
<GatewayOfflineBanner gatewayUnavailable={gatewayUnavailable} />
{children}
</SidebarInset>
</SidebarProvider>
<CommandPalette />
<Toaster position="top-center" />
@@ -0,0 +1,184 @@
"use client";
import { MessageCircleIcon } from "lucide-react";
import type { SVGProps } from "react";
import { cn } from "@/lib/utils";
type ChannelProviderIconProps = SVGProps<SVGSVGElement> & {
provider: string;
};
export function ChannelProviderIcon({
provider,
className,
...props
}: ChannelProviderIconProps) {
const normalizedProvider = provider.toLowerCase();
if (normalizedProvider === "telegram") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<circle cx="12" cy="12" r="11" fill="#2AABEE" />
<path
fill="#FFFFFF"
d="M17.4 7.2 15.7 16c-.1.7-.5.9-1 .6l-2.8-2.1-1.4 1.3c-.1.2-.3.3-.6.3l.2-2.9 5.3-4.8c.2-.2 0-.3-.3-.1l-6.6 4.1-2.8-.9c-.6-.2-.6-.6.1-.8l10.9-4.2c.5-.2.9.1.7.7Z"
/>
</svg>
);
}
if (normalizedProvider === "slack") {
return (
<svg
viewBox="0 0 256 256"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<path
fill="#E01E5A"
d="M53.841 161.32c0 14.832-11.987 26.82-26.819 26.82S.203 176.152.203 161.32c0-14.831 11.987-26.818 26.82-26.818H53.84zm13.41 0c0-14.831 11.987-26.818 26.819-26.818s26.819 11.987 26.819 26.819v67.047c0 14.832-11.987 26.82-26.82 26.82c-14.83 0-26.818-11.988-26.818-26.82z"
/>
<path
fill="#36C5F0"
d="M94.07 53.638c-14.832 0-26.82-11.987-26.82-26.819S79.239 0 94.07 0s26.819 11.987 26.819 26.819v26.82zm0 13.613c14.832 0 26.819 11.987 26.819 26.819s-11.987 26.819-26.82 26.819H26.82C11.987 120.889 0 108.902 0 94.069c0-14.83 11.987-26.818 26.819-26.818z"
/>
<path
fill="#2EB67D"
d="M201.55 94.07c0-14.832 11.987-26.82 26.818-26.82s26.82 11.988 26.82 26.82s-11.988 26.819-26.82 26.819H201.55zm-13.41 0c0 14.832-11.988 26.819-26.82 26.819c-14.831 0-26.818-11.987-26.818-26.82V26.82C134.502 11.987 146.489 0 161.32 0s26.819 11.987 26.819 26.819z"
/>
<path
fill="#ECB22E"
d="M161.32 201.55c14.832 0 26.82 11.987 26.82 26.818s-11.988 26.82-26.82 26.82c-14.831 0-26.818-11.988-26.818-26.82V201.55zm0-13.41c-14.831 0-26.818-11.988-26.818-26.82c0-14.831 11.987-26.818 26.819-26.818h67.25c14.832 0 26.82 11.987 26.82 26.819s-11.988 26.819-26.82 26.819z"
/>
</svg>
);
}
if (normalizedProvider === "discord") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<circle cx="12" cy="12" r="11" fill="#5865F2" />
<path
fill="#FFFFFF"
d="M8.1 8.4c1.4-.6 2.7-.7 3.9-.7s2.5.1 3.9.7c1 1.5 1.5 3.1 1.4 4.8-.9.7-1.8 1.1-2.8 1.3l-.7-1.1c.4-.1.7-.3 1.1-.5-.3.1-.6.3-.9.4-.7.3-1.4.4-2 .4s-1.3-.1-2-.4c-.3-.1-.6-.2-.9-.4.3.2.7.4 1.1.5l-.7 1.1c-1-.2-1.9-.6-2.8-1.3-.1-1.7.4-3.3 1.4-4.8Zm2.1 3.9c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Zm3.6 0c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Z"
/>
</svg>
);
}
if (normalizedProvider === "feishu") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<rect
x="1.25"
y="1.25"
width="21.5"
height="21.5"
rx="5.25"
fill="#FFFFFF"
stroke="#E5E7EB"
strokeWidth=".5"
/>
<path
d="M6.1 4.5h8.3c.9 0 1.7.4 2.2 1.1a16 16 0 0 1 2.9 6.2c-1.8-.8-3.8-.9-5.9-.3L5.6 5.8c-.6-.5-.3-1.3.5-1.3Z"
fill="#14D6C5"
/>
<path
d="M3.2 8.9c3.6 3.4 7.5 5.7 11.7 6.8 2.7.7 5.1.4 7-.6-1.6 3.1-5.2 5.4-9.4 5.4-3.4 0-6.7-.9-9.2-2.6a2 2 0 0 1-.9-1.7V9.6c0-.7.4-1 .8-.7Z"
fill="#3370FF"
/>
<path
d="M11 14.1c2.3-3.1 6.1-4.6 10.5-3.3l.8.2-2.6 4.1a6.3 6.3 0 0 1-6 2.9c-1.9-.2-3.9-.8-5.9-1.7 1.1-.7 2.2-1.4 3.2-2.2Z"
fill="#1E3A9F"
/>
</svg>
);
}
if (normalizedProvider === "dingtalk") {
return (
<svg
viewBox="0 0 1024 1024"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<g transform="translate(512 512) scale(1.35) translate(-512 -512)">
<path
fill="#0B86FF"
d="M739 449.3c-1 4.2-3.5 10.4-7 17.8h.1l-.4.7c-20.3 43.1-73.1 127.7-73.1 127.7l-.3-.5-15.5 26.8h74.5L575.1 810l32.3-128h-58.6l20.4-84.7c-16.5 3.9-35.9 9.4-59 16.8 0 0-31.2 18.2-89.9-35 0 0-39.6-34.7-16.6-43.4 9.8-3.7 47.4-8.4 77-12.3 40-5.4 64.6-8.2 64.6-8.2S422 517 392.7 512.5c-29.3-4.6-66.4-53.1-74.3-95.8 0 0-12.2-23.4 26.3-12.3s197.9 43.2 197.9 43.2-207.4-63.3-221.2-78.7-40.6-84.2-37.1-126.5c0 0 1.5-10.5 12.4-7.7 0 0 153.3 69.7 258.1 107.9 104.8 37.9 195.9 57.3 184.2 106.7Z"
/>
</g>
</svg>
);
}
if (normalizedProvider === "wechat") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<circle cx="12" cy="12" r="11" fill="#07C160" />
<path
fill="#FFFFFF"
d="M10.4 6.5c-3 0-5.4 2-5.4 4.5 0 1.4.8 2.7 2.1 3.5l-.5 1.8 2-.9c.6.1 1.2.2 1.8.2 3 0 5.4-2 5.4-4.5s-2.4-4.6-5.4-4.6Zm-1.9 3.7a.7.7 0 1 1 0-1.4.7.7 0 0 1 0 1.4Zm3.7 0a.7.7 0 1 1 0-1.4.7.7 0 0 1 0 1.4Z"
/>
<path
fill="#FFFFFF"
fillOpacity=".86"
d="M14.4 12.3c2.5 0 4.6 1.7 4.6 3.8 0 1.1-.6 2.1-1.6 2.8l.4 1.5-1.7-.8c-.5.1-1.1.2-1.7.2-2.5 0-4.6-1.7-4.6-3.8s2.1-3.7 4.6-3.7Zm-1.6 3.1a.6.6 0 1 0 0-1.2.6.6 0 0 0 0 1.2Zm3.1 0a.6.6 0 1 0 0-1.2.6.6 0 0 0 0 1.2Z"
/>
</svg>
);
}
if (normalizedProvider === "wecom") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<rect
x="1.25"
y="1.25"
width="21.5"
height="21.5"
rx="5.25"
fill="#FFFFFF"
stroke="#E5E7EB"
strokeWidth=".5"
/>
<path
fill="#168DEB"
d="m17.326 8.158-.003-.007a6.6 6.6 0 0 0-1.178-1.674c-1.266-1.307-3.067-2.19-5.102-2.417a9.3 9.3 0 0 0-2.124 0h-.001c-2.061.228-3.882 1.107-5.14 2.405a6.7 6.7 0 0 0-1.194 1.682A5.7 5.7 0 0 0 2 10.657c0 1.106.332 2.218.988 3.201l.006.01c.391.594 1.092 1.39 1.637 1.83l.983.793-.208.875.527-.267.708-.358.761.225c.467.137.955.227 1.517.29h.005q.515.06 1.026.059c.355 0 .724-.02 1.095-.06a9 9 0 0 0 1.346-.258c.095.7.43 1.337.932 1.81-.658.208-1.352.358-2.061.436-.442.048-.883.072-1.312.072q-.627 0-1.253-.072a10.7 10.7 0 0 1-1.861-.36l-2.84 1.438s-.29.131-.44.131c-.418 0-.702-.285-.702-.704 0-.252.067-.598.128-.84l.394-1.653c-.728-.586-1.563-1.544-2.052-2.287A7.76 7.76 0 0 1 0 10.658a7.7 7.7 0 0 1 .787-3.39 8.7 8.7 0 0 1 1.551-2.19c1.61-1.665 3.878-2.73 6.359-3.006a11.3 11.3 0 0 1 2.565 0c2.47.275 4.712 1.353 6.323 3.017a8.6 8.6 0 0 1 1.539 2.192c.466.945.769 1.937.769 2.978a3.06 3.06 0 0 0-2-.005c-.001-.644-.189-1.329-.564-2.09zm4.125 6.977-.024-.024-.024-.018-.024-.018-.096-.095a4.24 4.24 0 0 1-1.169-2.192q0-.038-.006-.075l-.006-.056-.035-.144a1.3 1.3 0 0 0-.358-.61 1.386 1.386 0 0 0-1.957 0 1.4 1.4 0 0 0 0 1.963c.191.191.418.311.668.371.024.012.06.012.084.012q.019 0 .041.006.023.005.042.006a4.24 4.24 0 0 1 2.231 1.186c.048.048.096.095.131.143a.323.323 0 0 0 .466 0 .35.35 0 0 0 .036-.455m-1.05 4.37-.025.025c-.119.096-.31.096-.453-.036a.326.326 0 0 1 0-.467c.047-.036.094-.083.141-.13l.002-.002a4.27 4.27 0 0 0 1.187-2.28q.005-.024.006-.043c0-.024 0-.06.012-.084a1.386 1.386 0 0 1 2.326-.67 1.4 1.4 0 0 1 0 1.964c-.167.18-.382.299-.608.359l-.143.036-.057.005q-.035.006-.075.007a4.2 4.2 0 0 0-2.183 1.173l-.095.096q-.009.01-.018.024t-.018.024m-4.392-1.053.024.024.024.018q.015.009.024.018l.096.096a4.25 4.25 0 0 1 1.169 2.19q0 .04.006.076.005.03.006.057l.035.143c.06.228.18.443.358.611.537.539 1.42.539 1.957 0a1.4 1.4 0 0 0 0-1.964 1.4 1.4 0 0 0-.668-.371c-.024-.012-.06-.012-.084-.012q-.018 0-.041-.006l-.042-.006a4.25 4.25 0 0 1-2.231-1.185 1.4 1.4 0 0 1-.131-.144.323.323 0 0 0-.466 0 .325.325 0 0 0-.036.455m1.039-4.358.024-.024a.32.32 0 0 1 .453.035.326.326 0 0 1 0 .467c-.047.036-.094.083-.141.13l-.002.002a4.27 4.27 0 0 0-1.187 2.281l-.006.042c0 .024 0 .06-.012.084a1.386 1.386 0 0 1-2.326.67 1.4 1.4 0 0 1 0-1.963c.166-.18.381-.3.608-.36l.143-.035q.026 0 .056-.006.037-.005.075-.006a4.2 4.2 0 0 0 2.183-1.174l.096-.095.018-.025z"
/>
</svg>
);
}
return (
<MessageCircleIcon aria-hidden="true" className={cn("size-5", className)} />
);
}
@@ -0,0 +1,159 @@
"use client";
import { LoaderCircleIcon } from "lucide-react";
import {
type CSSProperties,
type FormEvent,
useEffect,
useMemo,
useState,
} from "react";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import type {
ChannelProvider,
ChannelRuntimeConfigValues,
} from "@/core/channels/types";
import { useI18n } from "@/core/i18n/hooks";
type ChannelRuntimeConfigDialogProps = {
provider: ChannelProvider | null;
open: boolean;
submitting: boolean;
onOpenChange: (open: boolean) => void;
onSubmit: (
provider: ChannelProvider,
values: ChannelRuntimeConfigValues,
) => void;
};
type SecretInputStyle = CSSProperties & {
WebkitTextSecurity?: "disc";
};
const SECRET_INPUT_STYLE: SecretInputStyle = {
WebkitTextSecurity: "disc",
};
export function ChannelRuntimeConfigDialog({
provider,
open,
submitting,
onOpenChange,
onSubmit,
}: ChannelRuntimeConfigDialogProps) {
const { t } = useI18n();
const [values, setValues] = useState<ChannelRuntimeConfigValues>({});
const fields = useMemo(
() => provider?.credential_fields ?? [],
[provider?.credential_fields],
);
const credentialValues = useMemo<ChannelRuntimeConfigValues>(
() => provider?.credential_values ?? {},
[provider?.credential_values],
);
useEffect(() => {
if (!open || !provider) {
setValues({});
return;
}
setValues(
Object.fromEntries(
fields.map((field) => [field.name, credentialValues[field.name] ?? ""]),
) as ChannelRuntimeConfigValues,
);
}, [credentialValues, fields, open, provider]);
if (!provider) {
return null;
}
const isEditing = provider.configured;
const handleSubmit = (event: FormEvent<HTMLFormElement>) => {
event.preventDefault();
onSubmit(provider, values);
};
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent>
<form onSubmit={handleSubmit} className="space-y-4">
<DialogHeader>
<DialogTitle>
{isEditing
? t.channels.setupEditTitle(provider.display_name)
: t.channels.setupTitle(provider.display_name)}
</DialogTitle>
<DialogDescription>{t.channels.setupDescription}</DialogDescription>
</DialogHeader>
<div className="space-y-3">
{fields.map((field) => {
const inputId = `channel-${provider.provider}-${field.name}`;
const isSecretField = field.type === "password";
return (
<div key={field.name} className="space-y-1.5">
<label
htmlFor={inputId}
className="text-sm leading-none font-medium"
>
{field.label}
</label>
<Input
id={inputId}
type="text"
value={values[field.name] ?? ""}
required={field.required}
autoComplete="off"
autoCorrect="off"
autoCapitalize="none"
spellCheck={false}
data-1p-ignore={isSecretField ? "true" : undefined}
data-bwignore={isSecretField ? "true" : undefined}
data-form-type={isSecretField ? "other" : undefined}
data-lpignore={isSecretField ? "true" : undefined}
style={isSecretField ? SECRET_INPUT_STYLE : undefined}
onChange={(event) => {
setValues((current) => ({
...current,
[field.name]: event.target.value,
}));
}}
/>
</div>
);
})}
</div>
<DialogFooter>
<Button
type="button"
variant="outline"
disabled={submitting}
onClick={() => onOpenChange(false)}
>
{t.common.cancel}
</Button>
<Button type="submit" disabled={submitting}>
{submitting ? (
<LoaderCircleIcon className="animate-spin" />
) : null}
{isEditing ? t.channels.saveChanges : t.channels.saveAndConnect}
</Button>
</DialogFooter>
</form>
</DialogContent>
</Dialog>
);
}

Some files were not shown because too many files have changed in this diff Show More