From aa015462a7e9003c0f6973c66655cea25f9ba23f Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:24:58 +0800 Subject: [PATCH] feat(im): Add user-owned IM channel connections (#3487) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add user-owned IM channel connections * Fix dev startup and channel connect popup * Use async channel connect flow * Harden dev service daemon startup * Support local IM channel connections * Align IM connections with local channels * Fix safe user id digest algorithm * Address Copilot IM channel feedback * Address IM channel review comments * Support all integrated IM channel connections * Format additional channel connection tests * Keep unavailable channel connect buttons clickable * Fix IM channel provider icons * Add runtime setup for enabled IM channels * Guard global shortcut key handling * Keep configured IM channels editable * Avoid password autofill for channel secrets * Make channel threads visible to connection owners * Persist IM runtime config locally * Allow disconnecting runtime IM channels * Route no-auth channel sessions to local user * Use default user for auth-disabled local mode * Show IM channel source on threads * Prefill IM channel runtime config * Reflect IM channel runtime health * Ignore Feishu message read events * Ignore Feishu non-content message events * Let setup wizard enable IM channels * Fix frontend formatting after merge * Stabilize backend tests without local config * Isolate channel runtime config tests * Address channel connection review comments * Use sha256 user buckets with legacy migration * Ensure runtime IM channels are ready after restart * Persist disconnected IM channel state * Address channel connection review comments * Address channel connection review findings Frontend connect flow: - Open the runtime-config dialog only when a provider still needs credentials; configured 
providers go straight to the connect flow, so the binding-code/deep-link path is reachable from the UI again. - After saving credentials, continue into the connect flow when a user binding is still required (multi-user mode) instead of stopping at a "Connected" toast. - Extract shared provider-state helpers to core/channels/provider-state and add unit + e2e coverage for the direct-connect and configure-then-connect paths. Provider status semantics: - Report connection_status from the user's newest connection row; with no binding it is not_connected, except in auth-disabled local mode where a configured running channel is effectively connected. Concurrency and event-loop correctness: - Offload ChannelRuntimeConfigStore construction and writes, channel service construction, and Slack connection replies to threads; add a tests/blocking_io/ anchor for the runtime-config handlers. - Consume binding codes with a conditional UPDATE so a code can only be used once under concurrent workers; retry upsert_connection as an update when a concurrent insert wins the unique constraint. - Serialize ensure_channel_ready per channel so concurrent provider polls cannot double-start a channel worker. Config and migration hardening: - Stop mutating the get_app_config()-cached Telegram provider config; the runtime store now owns the UI-entered bot username. - Register channel_connections in STARTUP_ONLY_FIELDS with the standardized startup-only Field description. - Match the legacy unsafe-id bucket by recomputing its exact SHA-1 name so another user's same-prefix bucket can never be migrated. - Remove the unused Telegram process_webhook_update path and document src/core/channels in the frontend docs. 
Co-Authored-By: Claude Fable 5 * Address PR review comments on authz scoping and channel runtime Security (review feedback from ShenAC-SAC): - Scope internal-token callers to the connection owner carried in X-DeerFlow-Owner-User-Id instead of bypassing owner checks outright, in both require_permission(owner_check=True) and the stateless run endpoints. Internal callers keep access to their own and shared/legacy threads, and may claim a default-owned channel thread for its real owner, but a leaked internal token no longer grants cross-user thread access. - Require admin privileges for POST/DELETE /api/channels/{provider}/ runtime-config: runtime credentials and channel workers are instance-wide shared state (same model as the MCP config API). Read-only provider listing stays available to all users. Performance (review feedback from willem-bd): - Skip the redundant thread channel-metadata PATCH after the first successful backfill per thread. - Reuse the per-connection Slack WebClient until its token changes instead of constructing one per outbound message. - Reconcile channel readiness for all providers concurrently in GET /api/channels/providers. Also resolve the code-quality unused-import flag in the blocking-io anchor by pre-importing the channel service via importlib. Co-Authored-By: Claude Fable 5 * Fix prettier formatting in provider-state test Co-Authored-By: Claude Fable 5 * Reconcile UI runtime channel config with config reload on restart Main now reloads a channel's config.yaml entry on restart_channel() (#3514, issue #3497). Adapt the user-owned connection flow to coexist: - configure_channel() restarts with reload_config=False — the caller just supplied the authoritative config (browser-entered credentials that are never written to config.yaml), so a file reload must not clobber it with the stale on-disk entry. 
- _load_channel_config() re-applies the UI runtime-store overlay used at startup, so an operator-triggered restart keeps browser-entered credentials for channels without a config.yaml entry and does not resurrect a channel disconnected from the UI. - Offload the reload's disk IO (config.yaml + runtime store) with asyncio.to_thread, matching the blocking-IO policy on this branch. Co-Authored-By: Claude Fable 5 --------- Co-authored-by: Claude Fable 5 --- README.md | 2 + backend/CLAUDE.md | 35 +- backend/app/channels/commands.py | 11 + backend/app/channels/connection_identity.py | 44 + backend/app/channels/dingtalk.py | 106 +- backend/app/channels/discord.py | 66 +- backend/app/channels/feishu.py | 80 +- backend/app/channels/manager.py | 182 +++- backend/app/channels/message_bus.py | 14 + backend/app/channels/runtime_config_store.py | 154 +++ backend/app/channels/service.py | 171 +++- backend/app/channels/slack.py | 174 +++- backend/app/channels/telegram.py | 57 ++ backend/app/channels/wechat.py | 62 +- backend/app/channels/wecom.py | 59 +- backend/app/gateway/app.py | 4 + backend/app/gateway/auth_disabled.py | 6 +- backend/app/gateway/authz.py | 18 + backend/app/gateway/internal_auth.py | 27 +- .../gateway/routers/channel_connections.py | 670 ++++++++++++ backend/app/gateway/routers/threads.py | 12 +- backend/app/gateway/services.py | 145 +-- backend/docs/IM_CHANNEL_CONNECTIONS.md | 122 +++ .../harness/deerflow/config/app_config.py | 8 + .../config/channel_connections_config.py | 61 ++ .../packages/harness/deerflow/config/paths.py | 37 +- .../deerflow/config/reload_boundary.py | 3 + .../channel_connections/__init__.py | 21 + .../persistence/channel_connections/model.py | 111 ++ .../persistence/channel_connections/sql.py | 387 +++++++ .../deerflow/persistence/models/__init__.py | 18 +- .../deerflow/persistence/thread_meta/base.py | 9 + .../persistence/thread_meta/memory.py | 8 + .../deerflow/persistence/thread_meta/sql.py | 15 + 
.../harness/deerflow/runtime/runs/manager.py | 11 +- backend/packages/harness/pyproject.toml | 1 + .../test_channel_runtime_config_store.py | 106 ++ .../test_additional_channel_connections.py | 251 +++++ backend/tests/test_auth.py | 68 ++ backend/tests/test_auth_middleware.py | 21 +- .../tests/test_channel_connections_config.py | 56 + .../test_channel_connections_repository.py | 331 ++++++ .../tests/test_channel_connections_router.py | 963 ++++++++++++++++++ backend/tests/test_channels.py | 575 ++++++++++- backend/tests/test_csrf_middleware.py | 12 + .../tests/test_discord_channel_connections.py | 88 ++ backend/tests/test_feishu_parser.py | 25 + backend/tests/test_gateway_services.py | 95 ++ backend/tests/test_internal_auth.py | 15 + backend/tests/test_paths_user_isolation.py | 35 + backend/tests/test_reload_boundary.py | 1 + backend/tests/test_setup_wizard.py | 75 ++ .../tests/test_slack_channel_connections.py | 154 +++ .../test_stateless_runs_owner_isolation.py | 40 +- .../test_telegram_channel_connections.py | 100 ++ backend/tests/test_thread_meta_repo.py | 13 + backend/tests/test_threads_router.py | 32 + backend/uv.lock | 2 + config.example.yaml | 39 + frontend/AGENTS.md | 1 + frontend/CLAUDE.md | 1 + frontend/src/app/workspace/chats/page.tsx | 46 +- .../channels/channel-provider-icon.tsx | 184 ++++ .../channel-runtime-config-dialog.tsx | 159 +++ .../channels/workspace-channels-list.tsx | 213 ++++ .../components/workspace/recent-chat-list.tsx | 26 +- .../settings/channels-settings-page.tsx | 377 +++++++ .../workspace/settings/settings-dialog.tsx | 10 + .../workspace/thread-channel-source.tsx | 56 + .../workspace/workspace-sidebar.tsx | 2 + frontend/src/core/auth/auth-disabled-user.ts | 4 +- frontend/src/core/channels/api.ts | 117 +++ frontend/src/core/channels/hooks.ts | 96 ++ .../src/core/channels/open-connect-url.ts | 27 + frontend/src/core/channels/provider-state.ts | 22 + frontend/src/core/channels/types.ts | 53 + frontend/src/core/i18n/locales/en-US.ts | 
42 + frontend/src/core/i18n/locales/types.ts | 31 + frontend/src/core/i18n/locales/zh-CN.ts | 41 + frontend/src/core/threads/hooks.ts | 67 +- .../src/core/threads/thread-search-query.ts | 86 ++ frontend/src/core/threads/utils.ts | 45 + frontend/src/hooks/use-global-shortcuts.ts | 10 +- frontend/tests/e2e/channels.spec.ts | 452 ++++++++ frontend/tests/e2e/thread-history.spec.ts | 39 + frontend/tests/e2e/utils/mock-api.ts | 6 +- frontend/tests/unit/core/channels/api.test.ts | 220 ++++ .../core/channels/open-connect-url.test.ts | 86 ++ .../unit/core/channels/provider-state.test.ts | 89 ++ .../core/threads/thread-search-query.test.ts | 19 + .../tests/unit/core/threads/utils.test.ts | 39 +- .../unit/hooks/use-global-shortcuts.test.ts | 61 ++ scripts/setup_wizard.py | 11 +- scripts/wizard/steps/channels.py | 46 + scripts/wizard/ui.py | 43 + scripts/wizard/writer.py | 27 + 96 files changed, 8585 insertions(+), 277 deletions(-) create mode 100644 backend/app/channels/connection_identity.py create mode 100644 backend/app/channels/runtime_config_store.py create mode 100644 backend/app/gateway/routers/channel_connections.py create mode 100644 backend/docs/IM_CHANNEL_CONNECTIONS.md create mode 100644 backend/packages/harness/deerflow/config/channel_connections_config.py create mode 100644 backend/packages/harness/deerflow/persistence/channel_connections/__init__.py create mode 100644 backend/packages/harness/deerflow/persistence/channel_connections/model.py create mode 100644 backend/packages/harness/deerflow/persistence/channel_connections/sql.py create mode 100644 backend/tests/blocking_io/test_channel_runtime_config_store.py create mode 100644 backend/tests/test_additional_channel_connections.py create mode 100644 backend/tests/test_channel_connections_config.py create mode 100644 backend/tests/test_channel_connections_repository.py create mode 100644 backend/tests/test_channel_connections_router.py create mode 100644 backend/tests/test_discord_channel_connections.py create 
mode 100644 backend/tests/test_slack_channel_connections.py create mode 100644 backend/tests/test_telegram_channel_connections.py create mode 100644 frontend/src/components/workspace/channels/channel-provider-icon.tsx create mode 100644 frontend/src/components/workspace/channels/channel-runtime-config-dialog.tsx create mode 100644 frontend/src/components/workspace/channels/workspace-channels-list.tsx create mode 100644 frontend/src/components/workspace/settings/channels-settings-page.tsx create mode 100644 frontend/src/components/workspace/thread-channel-source.tsx create mode 100644 frontend/src/core/channels/api.ts create mode 100644 frontend/src/core/channels/hooks.ts create mode 100644 frontend/src/core/channels/open-connect-url.ts create mode 100644 frontend/src/core/channels/provider-state.ts create mode 100644 frontend/src/core/channels/types.ts create mode 100644 frontend/src/core/threads/thread-search-query.ts create mode 100644 frontend/tests/e2e/channels.spec.ts create mode 100644 frontend/tests/unit/core/channels/api.test.ts create mode 100644 frontend/tests/unit/core/channels/open-connect-url.test.ts create mode 100644 frontend/tests/unit/core/channels/provider-state.test.ts create mode 100644 frontend/tests/unit/core/threads/thread-search-query.test.ts create mode 100644 frontend/tests/unit/hooks/use-global-shortcuts.test.ts create mode 100644 scripts/wizard/steps/channels.py diff --git a/README.md b/README.md index 73a8e719b..7015cb159 100644 --- a/README.md +++ b/README.md @@ -343,6 +343,8 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them. +DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. 
It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes. + | Channel | Transport | Difficulty | |---------|-----------|------------| | Telegram | Bot API (long-polling) | Easy | diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 98623bf60..66155ed32 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -234,7 +234,7 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc **Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`. -Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`. +Infrastructure fields are **restart-required**. 
The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`, `channel_connections`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`. Configuration priority: 1. Explicit `config_path` argument @@ -377,8 +377,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti ### IM Channels System (`app/channels/`) -Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API. - +Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API. **Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies. 
@@ -388,18 +387,21 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the - `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates - `base.py` - Abstract `Channel` base class (start/stop/send lifecycle) - `service.py` - Manages lifecycle of all configured channels from `config.yaml` -- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured) +- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured) +- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs +- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store **Message Flow**: 1. External platform -> Channel impl -> `MessageBus.publish_inbound()` 2. `ChannelManager._dispatch_loop()` consumes from queue -3. For chat: look up/create thread through Gateway's LangGraph-compatible API -4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`) -5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound -6. 
Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement) -7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails -8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API -9. Outbound → channel callbacks → platform reply +3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id` +4. For chat: look up/create thread through Gateway's LangGraph-compatible API +5. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`) +6. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound +7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement) +8. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails +9. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API +10. 
Outbound → channel callbacks → platform reply **Configuration** (`config.yaml` -> `channels`): - `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`) @@ -407,6 +409,17 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the - In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`. - Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming) +**User-owned channel connections** (`config.yaml` -> `channel_connections`): +- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials. +- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation. +- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers. +- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`. +- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers. +- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user. 
+- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable. +- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`. +- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes. + ### Memory System (`packages/harness/deerflow/agents/memory/`) diff --git a/backend/app/channels/commands.py b/backend/app/channels/commands.py index c783899c5..86e4e9105 100644 --- a/backend/app/channels/commands.py +++ b/backend/app/channels/commands.py @@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset( ) +def extract_connect_code(text: str) -> str | None: + """Extract the one-time channel binding code from a connect command.""" + parts = text.strip().split() + if len(parts) < 2: + return None + command = parts[0].lower() + if command in {"/connect", "connect"}: + return parts[1] + return None + + def is_known_channel_command(text: str) -> bool: """Return whether text starts with a registered channel control command.""" if not text.startswith("/"): diff --git a/backend/app/channels/connection_identity.py b/backend/app/channels/connection_identity.py new file mode 100644 index 000000000..162498aff --- /dev/null +++ b/backend/app/channels/connection_identity.py @@ -0,0 +1,44 @@ +"""Helpers for attaching persisted channel connection ownership to inbound messages.""" + +from __future__ import annotations + +from typing import Any + +from app.channels.message_bus import InboundMessage + + +async def attach_connection_identity( + inbound: InboundMessage, + *, + repo: Any, + provider: str, + workspace_id: str | None, + fallback_without_workspace: bool = False, +) -> InboundMessage: + """Attach connection metadata to an inbound message when a persisted binding exists.""" + if repo is None: + return 
inbound + + workspace_candidates: list[str | None] = [] + if workspace_id: + workspace_candidates.append(workspace_id) + if fallback_without_workspace: + workspace_candidates.append(None) + if not workspace_candidates: + return inbound + + for candidate in workspace_candidates: + connection = await repo.find_connection_by_external_identity( + provider=provider, + external_account_id=inbound.user_id, + workspace_id=candidate, + ) + if connection is None: + continue + + inbound.connection_id = connection["id"] + inbound.owner_user_id = connection["owner_user_id"] + inbound.workspace_id = connection.get("workspace_id") + return inbound + + return inbound diff --git a/backend/app/channels/dingtalk.py b/backend/app/channels/dingtalk.py index fb53ce272..85bbc30e5 100644 --- a/backend/app/channels/dingtalk.py +++ b/backend/app/channels/dingtalk.py @@ -14,7 +14,8 @@ from typing import Any import httpx from app.channels.base import Channel -from app.channels.commands import is_known_channel_command +from app.channels.commands import extract_connect_code, is_known_channel_command +from app.channels.connection_identity import attach_connection_identity from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -136,6 +137,7 @@ class DingTalkChannel(Channel): self._incoming_messages: dict[str, Any] = {} self._incoming_messages_lock = threading.Lock() self._card_repliers: dict[str, Any] = {} + self._connection_repo = config.get("connection_repo") @property def supports_streaming(self) -> bool: @@ -395,6 +397,24 @@ class DingTalkChannel(Channel): text[:100], ) + connect_code = extract_connect_code(text) + if connect_code and self._connection_repo is not None: + if self._main_loop and self._main_loop.is_running(): + fut = asyncio.run_coroutine_threadsafe( + self._bind_connection_from_connect_code( + conversation_type=conversation_type, + sender_staff_id=sender_staff_id, + 
sender_nick=sender_nick, + conversation_id=conversation_id, + code=connect_code, + ), + self._main_loop, + ) + fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid)) + else: + logger.warning("[DingTalk] main loop not running, cannot bind channel connection") + return + if _is_dingtalk_command(text): msg_type = InboundMessageType.COMMAND else: @@ -450,11 +470,95 @@ class DingTalkChannel(Channel): return "" async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None: + inbound = await self._attach_connection_identity(inbound) # Running reply must finish before publish_inbound so AI card tracks are # registered before the manager emits streaming outbounds. await self._send_running_reply(chat_id, inbound) await self.bus.publish_inbound(inbound) + @staticmethod + def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None: + if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id: + return conversation_id + return None + + async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage: + conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P) + conversation_id = str(inbound.metadata.get("conversation_id") or "") + return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="dingtalk", + workspace_id=self._connection_workspace_id(conversation_type, conversation_id), + fallback_without_workspace=True, + ) + + async def _bind_connection_from_connect_code( + self, + *, + conversation_type: str, + sender_staff_id: str, + sender_nick: str, + conversation_id: str, + code: str, + ) -> bool: + if self._connection_repo is None or not code: + return False + + state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code) + if state is None: + await self._send_connection_reply( + conversation_type, + sender_staff_id, + conversation_id, + "DingTalk connection code 
is invalid or expired.", + ) + return True + + if not sender_staff_id: + await self._send_connection_reply( + conversation_type, + sender_staff_id, + conversation_id, + "DingTalk connection could not be completed from this message.", + ) + return True + + await self._connection_repo.upsert_connection( + owner_user_id=state["owner_user_id"], + provider="dingtalk", + external_account_id=sender_staff_id, + external_account_name=sender_nick or None, + workspace_id=self._connection_workspace_id(conversation_type, conversation_id), + metadata={ + "conversation_type": conversation_type, + "conversation_id": conversation_id, + }, + status="connected", + ) + await self._send_connection_reply( + conversation_type, + sender_staff_id, + conversation_id, + "DingTalk connected to DeerFlow.", + ) + return True + + async def _send_connection_reply( + self, + conversation_type: str, + sender_staff_id: str, + conversation_id: str, + text: str, + ) -> None: + robot_code = self._client_id + if conversation_type == _CONVERSATION_TYPE_GROUP: + if conversation_id: + await self._send_text_message_to_group(robot_code, conversation_id, text) + return + if sender_staff_id: + await self._send_text_message_to_user(robot_code, sender_staff_id, text) + async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None: conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P) sender_staff_id = inbound.metadata.get("sender_staff_id", "") diff --git a/backend/app/channels/discord.py b/backend/app/channels/discord.py index c88eb0239..d81a71fd6 100644 --- a/backend/app/channels/discord.py +++ b/backend/app/channels/discord.py @@ -10,8 +10,9 @@ from pathlib import Path from typing import Any from app.channels.base import Channel -from app.channels.commands import is_known_channel_command -from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment +from app.channels.commands import extract_connect_code, 
is_known_channel_command +from app.channels.connection_identity import attach_connection_identity +from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -70,6 +71,7 @@ class DiscordChannel(Channel): self._discord_loop: asyncio.AbstractEventLoop | None = None self._main_loop: asyncio.AbstractEventLoop | None = None self._discord_module = None + self._connection_repo = config.get("connection_repo") async def start(self) -> None: if self._running: @@ -287,6 +289,10 @@ class DiscordChannel(Channel): text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip() # Don't return early if text is empty — still process the mention (e.g., create thread) + connect_code = extract_connect_code(text) + if connect_code and await self._bind_connection_from_connect_code(message, connect_code): + return + # --- Determine thread/channel routing and typing target --- thread_id = None chat_id = None @@ -315,6 +321,7 @@ class DiscordChannel(Channel): }, ) inbound.topic_id = thread_id + inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None) self._publish(inbound) # Start typing indicator in the thread if typing_target: @@ -422,6 +429,7 @@ class DiscordChannel(Channel): }, ) inbound.topic_id = thread_id + inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None) # Start typing indicator in the correct target (thread or channel) if typing_target: @@ -436,6 +444,60 @@ class DiscordChannel(Channel): future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop) future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None) + async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage: 
+ return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="discord", + workspace_id=guild_id, + fallback_without_workspace=True, + ) + + async def _bind_connection_from_connect_code(self, message, code: str) -> bool: + if self._connection_repo is None or not code: + return False + + state = await self._connection_repo.consume_oauth_state(provider="discord", state=code) + if state is None: + await self._send_connection_reply(message, "Discord connection code is invalid or expired.") + return True + + guild = getattr(message, "guild", None) + channel = getattr(message, "channel", None) + author = getattr(message, "author", None) + user_id = str(getattr(author, "id", "") or "") + if not user_id: + await self._send_connection_reply(message, "Discord connection could not be completed from this message.") + return True + + guild_id = str(getattr(guild, "id", "") or "") or None + await self._connection_repo.upsert_connection( + owner_user_id=state["owner_user_id"], + provider="discord", + external_account_id=user_id, + external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None), + workspace_id=guild_id, + workspace_name=getattr(guild, "name", None) if guild is not None else None, + metadata={ + "guild_id": guild_id, + "channel_id": str(getattr(channel, "id", "") or ""), + }, + status="connected", + ) + await self._send_connection_reply(message, "Discord connected to DeerFlow.") + return True + + @staticmethod + async def _send_connection_reply(message, text: str) -> None: + channel = getattr(message, "channel", None) + send = getattr(channel, "send", None) + if send is None: + return + try: + await send(text) + except Exception: + logger.exception("[Discord] failed to send connection reply") + def _run_client(self) -> None: self._discord_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._discord_loop) diff --git a/backend/app/channels/feishu.py b/backend/app/channels/feishu.py index 
fddbc7186..fa1c4a5d3 100644 --- a/backend/app/channels/feishu.py +++ b/backend/app/channels/feishu.py @@ -11,7 +11,8 @@ import time from typing import Any, Literal from app.channels.base import Channel -from app.channels.commands import is_known_channel_command +from app.channels.commands import extract_connect_code, is_known_channel_command +from app.channels.connection_identity import attach_connection_identity from app.channels.message_bus import ( PENDING_CLARIFICATION_METADATA_KEY, RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY, @@ -71,6 +72,7 @@ class FeishuChannel(Channel): self._CreateImageRequestBody = None self._GetMessageResourceRequest = None self._thread_lock = threading.Lock() + self._connection_repo = config.get("connection_repo") @staticmethod def _non_empty_str(value: Any) -> str | None: @@ -86,6 +88,23 @@ class FeishuChannel(Channel): def supports_streaming(self) -> bool: return True + @property + def is_running(self) -> bool: + if not self._running: + return False + return self._thread is not None and self._thread.is_alive() + + def _build_event_handler(self, lark): + return ( + lark.EventDispatcherHandler.builder("", "") + .register_p2_im_message_receive_v1(self._on_message) + .register_p2_im_message_message_read_v1(self._on_ignored_message_event) + .register_p2_im_message_reaction_created_v1(self._on_ignored_message_event) + .register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event) + .register_p2_im_message_recalled_v1(self._on_ignored_message_event) + .build() + ) + async def start(self) -> None: if self._running: return @@ -179,7 +198,7 @@ class FeishuChannel(Channel): # thread's uvloop. 
_ws_client_mod.loop = loop - event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build() + event_handler = self._build_event_handler(lark) ws_client = lark.ws.Client( app_id=app_id, app_secret=app_secret, @@ -191,6 +210,10 @@ class FeishuChannel(Channel): except Exception: if self._running: logger.exception("Feishu WebSocket error") + self._running = False + + def _on_ignored_message_event(self, event) -> None: + logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__) async def stop(self) -> None: self._running = False @@ -726,11 +749,47 @@ class FeishuChannel(Channel): async def _prepare_inbound(self, msg_id: str, inbound) -> None: """Kick off Feishu side effects without delaying inbound dispatch.""" + inbound = await self._attach_connection_identity(inbound) reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK")) self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id) self._ensure_running_card_started(msg_id) await self.bus.publish_inbound(inbound) + async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage: + return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="feishu", + workspace_id=inbound.chat_id, + ) + + async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool: + if self._connection_repo is None or not code: + return False + + state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code) + if state is None: + await self._reply_card(message_id, "Feishu connection code is invalid or expired.") + return True + + if not user_id or not chat_id: + await self._reply_card(message_id, "Feishu connection could not be completed from this message.") + return True + + await self._connection_repo.upsert_connection( + owner_user_id=state["owner_user_id"], + provider="feishu", + 
external_account_id=user_id, + workspace_id=chat_id, + metadata={ + "chat_id": chat_id, + "message_id": message_id, + }, + status="connected", + ) + await self._reply_card(message_id, "Feishu connected to DeerFlow.") + return True + def _on_message(self, event) -> None: """Called by lark-oapi when a message is received (runs in lark thread).""" try: @@ -819,6 +878,23 @@ class FeishuChannel(Channel): logger.info("[Feishu] empty text, ignoring message") return + connect_code = extract_connect_code(text) + if connect_code and self._connection_repo is not None: + if self._main_loop and self._main_loop.is_running(): + fut = asyncio.run_coroutine_threadsafe( + self._bind_connection_from_connect_code( + message_id=msg_id, + chat_id=chat_id, + user_id=sender_id, + code=connect_code, + ), + self._main_loop, + ) + fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid)) + else: + logger.warning("[Feishu] main loop not running, cannot bind channel connection") + return + # Only treat known slash commands as commands; absolute paths and # other slash-prefixed text should be handled as normal chat. 
if _is_feishu_command(text): diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 673723d6e..b3cd23765 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification: return metadata +def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]: + channel_source: dict[str, Any] = { + "type": "im_channel", + "provider": msg.channel_name, + "chat_id": msg.chat_id, + } + if msg.topic_id: + channel_source["topic_id"] = msg.topic_id + if msg.thread_ts: + channel_source["thread_ts"] = msg.thread_ts + if msg.connection_id: + channel_source["connection_id"] = msg.connection_id + + return {"channel_source": channel_source} + + def _extract_text_content(content: Any) -> str: """Extract text from a streaming payload content field.""" if isinstance(content, str): @@ -440,6 +456,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) - return message +def _auth_disabled_owner_user_id() -> str | None: + try: + from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled + except Exception: + logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True) + return None + return AUTH_DISABLED_USER_ID if is_auth_disabled() else None + + +def _effective_owner_user_id(msg: InboundMessage) -> str | None: + return _auth_disabled_owner_user_id() or msg.owner_user_id + + +def _apply_effective_owner(msg: InboundMessage) -> InboundMessage: + owner_user_id = _effective_owner_user_id(msg) + if owner_user_id: + msg.owner_user_id = owner_user_id + return msg + + +def _owner_headers(msg: InboundMessage) -> dict[str, str] | None: + owner_user_id = _effective_owner_user_id(msg) + if not owner_user_id: + return None + return create_internal_auth_headers(owner_user_id=owner_user_id) + + +def _safe_user_id_for_run(raw_user_id: str) -> str: + from deerflow.config.paths 
import get_paths + + try: + return get_paths().prepare_user_dir_for_raw_id(raw_user_id) + except Exception: + logger.exception("Failed to prepare channel run user directory") + return make_safe_user_id(raw_user_id) + + def _resolve_slash_skill_command( text: str, available_skills: set[str] | None = None, @@ -670,6 +723,7 @@ class ChannelManager: assistant_id: str = DEFAULT_ASSISTANT_ID, default_session: dict[str, Any] | None = None, channel_sessions: dict[str, Any] | None = None, + connection_repo: Any | None = None, ) -> None: self.bus = bus self.store = store @@ -679,7 +733,9 @@ class ChannelManager: self._assistant_id = assistant_id self._default_session = _as_dict(default_session) self._channel_sessions = dict(channel_sessions or {}) + self._connection_repo = connection_repo self._client = None # lazy init — langgraph_sdk async client + self._channel_metadata_synced: set[str] = set() self._skill_storage: SkillStorage | None = None self._csrf_token = generate_csrf_token() self._semaphore: asyncio.Semaphore | None = None @@ -728,12 +784,17 @@ class ChannelManager: configurable["checkpoint_ns"] = "" configurable["thread_id"] = thread_id - # ``user_id`` drives user-scoped filesystem buckets that only accept - # ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value - # under ``channel_user_id`` for platform-facing lookups. + # ``user_id`` drives DeerFlow-owned memory, files, and thread buckets. + # For browser-connected IM channels, prefer the DeerFlow account that + # owns the connection. Preserve the raw platform user under + # ``channel_user_id`` for platform-facing lookups and audits. 
run_context_identity: dict[str, Any] = {"thread_id": thread_id} + owner_user_id = _effective_owner_user_id(msg) + if owner_user_id: + run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id) + elif msg.user_id: + run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id) if msg.user_id: - run_context_identity["user_id"] = make_safe_user_id(msg.user_id) run_context_identity["channel_user_id"] = msg.user_id run_context = _merge_dicts( @@ -845,6 +906,7 @@ class ChannelManager: logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc) async def _handle_message(self, msg: InboundMessage) -> None: + msg = _apply_effective_owner(msg) async with self._semaphore: try: if msg.msg_type == InboundMessageType.COMMAND: @@ -877,10 +939,27 @@ class ChannelManager: # -- chat handling ----------------------------------------------------- - async def _create_thread(self, client, msg: InboundMessage) -> str: - """Create a new thread through Gateway and store the mapping.""" - thread = await client.threads.create() - thread_id = thread["thread_id"] + async def _lookup_thread_id(self, msg: InboundMessage) -> str | None: + if msg.connection_id and self._connection_repo is not None: + return await self._connection_repo.get_thread_id( + msg.connection_id, + msg.chat_id, + msg.topic_id, + ) + return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) + + async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None: + if msg.connection_id and msg.owner_user_id and self._connection_repo is not None: + await self._connection_repo.set_thread_id( + connection_id=msg.connection_id, + owner_user_id=msg.owner_user_id, + provider=msg.channel_name, + external_conversation_id=msg.chat_id, + external_topic_id=msg.topic_id, + thread_id=thread_id, + ) + return + self.store.set_thread_id( msg.channel_name, msg.chat_id, @@ -888,18 +967,49 @@ class ChannelManager: topic_id=msg.topic_id, user_id=msg.user_id, ) + + 
async def _create_thread(self, client, msg: InboundMessage) -> str: + """Create a new thread through Gateway and store the mapping.""" + metadata = _thread_channel_metadata(msg) + owner_headers = _owner_headers(msg) + if owner_headers: + thread = await client.threads.create(metadata=metadata, headers=owner_headers) + else: + thread = await client.threads.create(metadata=metadata) + thread_id = thread["thread_id"] + await self._store_thread_id(msg, thread_id) logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id) return thread_id + async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None: + """Best-effort source metadata backfill for existing IM-created threads.""" + # The metadata (provider/chat/topic) is constant for a thread, so one + # successful backfill per manager lifetime is enough — skip the + # redundant PATCH on every subsequent inbound message. + if thread_id in self._channel_metadata_synced: + return + update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)} + if owner_headers := _owner_headers(msg): + update_kwargs["headers"] = owner_headers + try: + await client.threads.update(thread_id, **update_kwargs) + except Exception: + logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True) + return + if len(self._channel_metadata_synced) > 4096: + self._channel_metadata_synced.clear() + self._channel_metadata_synced.add(thread_id) + async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None: client = self._get_client() # Look up existing DeerFlow thread. # topic_id may be None (e.g. Telegram private chats) — the store # handles this by using the "channel:chat_id" key without a topic suffix. 
- thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) + thread_id = await self._lookup_thread_id(msg) if thread_id: logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id) + await self._update_thread_channel_metadata(client, msg, thread_id) # No existing thread found — create a new one if thread_id is None: @@ -940,14 +1050,19 @@ class ChannelManager: return logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) + run_kwargs: dict[str, Any] = { + "input": {"messages": [human_message]}, + "config": run_config, + "context": run_context, + "multitask_strategy": "reject", + } + if owner_headers := _owner_headers(msg): + run_kwargs["headers"] = owner_headers try: result = await client.runs.wait( thread_id, assistant_id, - input={"messages": [human_message]}, - config=run_config, - context=run_context, - multitask_strategy="reject", + **run_kwargs, ) except Exception as exc: if _is_thread_busy_error(exc): @@ -984,6 +1099,8 @@ class ChannelManager: artifacts=artifacts, attachments=attachments, thread_ts=msg.thread_ts, + connection_id=msg.connection_id, + owner_user_id=msg.owner_user_id, metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), ) logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id) @@ -1008,16 +1125,21 @@ class ChannelManager: last_published_text = "" last_publish_at = 0.0 stream_error: BaseException | None = None + stream_kwargs: dict[str, Any] = { + "input": {"messages": [human_message]}, + "config": run_config, + "context": run_context, + "stream_mode": ["messages-tuple", "values"], + "multitask_strategy": "reject", + } + if owner_headers := _owner_headers(msg): + stream_kwargs["headers"] = owner_headers try: async for chunk in client.runs.stream( thread_id, assistant_id, - input={"messages": [human_message]}, - config=run_config, - 
context=run_context, - stream_mode=["messages-tuple", "values"], - multitask_strategy="reject", + **stream_kwargs, ): event = getattr(chunk, "event", "") data = getattr(chunk, "data", None) @@ -1047,6 +1169,8 @@ class ChannelManager: text=latest_text, is_final=False, thread_ts=msg.thread_ts, + connection_id=msg.connection_id, + owner_user_id=msg.owner_user_id, metadata=_response_metadata(msg.metadata), ) ) @@ -1093,6 +1217,8 @@ class ChannelManager: attachments=attachments, is_final=True, thread_ts=msg.thread_ts, + connection_id=msg.connection_id, + owner_user_id=msg.owner_user_id, metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), ) ) @@ -1124,18 +1250,10 @@ class ChannelManager: if reply is None and command == "new": # Create a new thread through Gateway client = self._get_client() - thread = await client.threads.create() - new_thread_id = thread["thread_id"] - self.store.set_thread_id( - msg.channel_name, - msg.chat_id, - new_thread_id, - topic_id=msg.topic_id, - user_id=msg.user_id, - ) + await self._create_thread(client, msg) reply = "New conversation started." elif reply is None and command == "status": - thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) + thread_id = await self._lookup_thread_id(msg) reply = f"Active thread: {thread_id}" if thread_id else "No active conversation." 
elif reply is None and command == "models": reply = await self._fetch_gateway("/api/models", "models") @@ -1174,9 +1292,11 @@ class ChannelManager: outbound = OutboundMessage( channel_name=msg.channel_name, chat_id=msg.chat_id, - thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "", + thread_id=await self._lookup_thread_id(msg) or "", text=reply, thread_ts=msg.thread_ts, + connection_id=msg.connection_id, + owner_user_id=msg.owner_user_id, metadata=_slim_metadata(msg.metadata), ) await self.bus.publish_outbound(outbound) @@ -1212,9 +1332,11 @@ class ChannelManager: outbound = OutboundMessage( channel_name=msg.channel_name, chat_id=msg.chat_id, - thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "", + thread_id=await self._lookup_thread_id(msg) or "", text=error_text, thread_ts=msg.thread_ts, + connection_id=msg.connection_id, + owner_user_id=msg.owner_user_id, metadata=_slim_metadata(msg.metadata), ) await self.bus.publish_outbound(outbound) diff --git a/backend/app/channels/message_bus.py b/backend/app/channels/message_bus.py index 4e847cca0..64a3c2271 100644 --- a/backend/app/channels/message_bus.py +++ b/backend/app/channels/message_bus.py @@ -44,6 +44,12 @@ class InboundMessage: Messages sharing the same ``topic_id`` within a ``chat_id`` will reuse the same DeerFlow thread. When ``None``, each message creates a new thread (one-shot Q&A). + connection_id: Optional DeerFlow channel connection id. When present, + conversation mapping is scoped by the connection instead of the + legacy global ``channel_name:chat_id[:topic_id]`` key. + owner_user_id: DeerFlow user id that owns the channel connection. + Platform user ids stay in ``user_id``. + workspace_id: Optional external workspace/guild/team id. files: Optional list of file attachments (platform-specific dicts). metadata: Arbitrary extra data from the channel. created_at: Unix timestamp when the message was created. 
@@ -56,6 +62,9 @@ class InboundMessage: msg_type: InboundMessageType = InboundMessageType.CHAT thread_ts: str | None = None topic_id: str | None = None + connection_id: str | None = None + owner_user_id: str | None = None + workspace_id: str | None = None files: list[dict[str, Any]] = field(default_factory=list) metadata: dict[str, Any] = field(default_factory=dict) created_at: float = field(default_factory=time.time) @@ -95,6 +104,9 @@ class OutboundMessage: is_final: Whether this is the final message in the response stream. thread_ts: Optional platform thread identifier for threaded replies. metadata: Arbitrary extra data. + connection_id: Optional DeerFlow channel connection id used for + connection-specific outbound credentials. + owner_user_id: DeerFlow user id that owns the channel connection. created_at: Unix timestamp. """ @@ -106,6 +118,8 @@ class OutboundMessage: attachments: list[ResolvedAttachment] = field(default_factory=list) is_final: bool = True thread_ts: str | None = None + connection_id: str | None = None + owner_user_id: str | None = None metadata: dict[str, Any] = field(default_factory=dict) created_at: float = field(default_factory=time.time) diff --git a/backend/app/channels/runtime_config_store.py b/backend/app/channels/runtime_config_store.py new file mode 100644 index 000000000..497e40623 --- /dev/null +++ b/backend/app/channels/runtime_config_store.py @@ -0,0 +1,154 @@ +"""Local persistence for runtime IM channel configuration.""" + +from __future__ import annotations + +import json +import logging +import tempfile +import threading +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled" + + +class ChannelRuntimeConfigStore: + """JSON-backed store for channel credentials entered from the UI. 
+ + This intentionally mirrors ``ChannelStore``: local/private deployments get + durable runtime configuration without needing a public callback URL or a + config.yaml edit. + """ + + def __init__(self, path: str | Path | None = None) -> None: + if path is None: + from deerflow.config.paths import get_paths + + path = Path(get_paths().base_dir) / "channels" / "runtime-config.json" + self._path = Path(path) + self._path.parent.mkdir(parents=True, exist_ok=True) + self._data: dict[str, dict[str, Any]] = self._load() + self._lock = threading.Lock() + + def _load(self) -> dict[str, dict[str, Any]]: + if self._path.exists(): + try: + raw = json.loads(self._path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError): + logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path) + return {} + if isinstance(raw, dict): + return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)} + return {} + + def _save(self) -> None: + fd = tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", + dir=self._path.parent, + suffix=".tmp", + delete=False, + ) + try: + json.dump(self._data, fd, indent=2, ensure_ascii=False) + fd.close() + Path(fd.name).replace(self._path) + try: + self._path.chmod(0o600) + except OSError: + logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True) + except BaseException: + fd.close() + Path(fd.name).unlink(missing_ok=True) + raise + + def load_all(self) -> dict[str, dict[str, Any]]: + with self._lock: + return {name: dict(config) for name, config in self._data.items()} + + def get_provider_config(self, provider: str) -> dict[str, Any] | None: + with self._lock: + config = self._data.get(provider) + return dict(config) if isinstance(config, dict) else None + + def set_provider_config(self, provider: str, config: dict[str, Any]) -> None: + with self._lock: + self._data[provider] = dict(config) + self._save() + + def set_provider_disconnected(self, provider: str) -> 
None: + with self._lock: + self._data[provider] = { + "enabled": False, + RUNTIME_CHANNEL_DISABLED_FLAG: True, + } + self._save() + + def remove_provider_config(self, provider: str) -> bool: + with self._lock: + if provider not in self._data: + return False + del self._data[provider] + self._save() + return True + + +def _provider_enabled(channel_connections_config: Any, provider: str) -> bool: + provider_config = getattr(channel_connections_config, provider, None) + return bool(getattr(provider_config, "enabled", False)) + + +def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool: + return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False + + +def merge_runtime_channel_configs( + channels_config: dict[str, Any], + channel_connections_config: Any, + *, + store: ChannelRuntimeConfigStore | None = None, +) -> None: + """Merge persisted runtime provider config into ``channels_config`` in-place.""" + if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False): + return + + runtime_store = store or ChannelRuntimeConfigStore() + for provider, runtime_config in runtime_store.load_all().items(): + if not _provider_enabled(channel_connections_config, provider): + continue + if _runtime_channel_disconnected(runtime_config): + channels_config.pop(provider, None) + continue + existing = channels_config.get(provider) + merged = dict(runtime_config) + if isinstance(existing, dict): + merged.update(existing) + channels_config[provider] = merged + + +def apply_runtime_connection_config( + channel_connections_config: Any, + *, + store: ChannelRuntimeConfigStore | None = None, +) -> Any: + """Apply persisted connection metadata that lives outside ``channels``. + + Telegram uses a bot username for deep links; UI-entered values are stored + with the runtime channel config so local restarts keep the provider + configured. 
+ """ + if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False): + return channel_connections_config + + runtime_store = store or ChannelRuntimeConfigStore() + telegram_runtime_config = runtime_store.get_provider_config("telegram") + bot_username = "" + if isinstance(telegram_runtime_config, dict): + bot_username = str(telegram_runtime_config.get("bot_username") or "").strip() + if not bot_username or not _provider_enabled(channel_connections_config, "telegram"): + return channel_connections_config + + config = channel_connections_config.model_copy(deep=True) + config.telegram.bot_username = bot_username + return config diff --git a/backend/app/channels/service.py b/backend/app/channels/service.py index f7bc7eaa0..a222c5a01 100644 --- a/backend/app/channels/service.py +++ b/backend/app/channels/service.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging import os from typing import TYPE_CHECKING, Any @@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any from app.channels.base import Channel from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager from app.channels.message_bus import MessageBus +from app.channels.runtime_config_store import merge_runtime_channel_configs from app.channels.store import ChannelStore logger = logging.getLogger(__name__) @@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL" _CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL" +def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool: + cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, []) + return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys) + + def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str: value = config.pop(config_key, None) if isinstance(value, str) and 
value.strip(): @@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, return default +def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None: + connection_config = getattr(app_config, "channel_connections", None) + merge_runtime_channel_configs(channels_config, connection_config) + + +def _make_connection_repo(app_config: AppConfig): + connection_config = getattr(app_config, "channel_connections", None) + if connection_config is None or not getattr(connection_config, "enabled", False): + return None + + try: + from deerflow.persistence.channel_connections import ChannelConnectionRepository + from deerflow.persistence.engine import get_session_factory + except Exception: + logger.exception("Failed to import channel connection repository") + return None + + session_factory = get_session_factory() + if session_factory is None: + logger.warning("Channel connections are enabled but database persistence is not available") + return None + return ChannelConnectionRepository(session_factory) + + class ChannelService: """Manages the lifecycle of all configured IM channels. @@ -59,9 +90,10 @@ class ChannelService: instantiates enabled channels, and starts the ChannelManager dispatcher. 
""" - def __init__(self, channels_config: dict[str, Any] | None = None) -> None: + def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None: self.bus = MessageBus() self.store = ChannelStore() + self._connection_repo = connection_repo config = dict(channels_config or {}) langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL) gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL) @@ -74,10 +106,12 @@ class ChannelService: gateway_url=gateway_url, default_session=default_session if isinstance(default_session, dict) else None, channel_sessions=channel_sessions, + connection_repo=connection_repo, ) self._channels: dict[str, Any] = {} # name -> Channel instance self._config = config self._running = False + self._readiness_locks: dict[str, asyncio.Lock] = {} @classmethod def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService: @@ -90,8 +124,9 @@ class ChannelService: # extra fields are allowed by AppConfig (extra="allow") extra = app_config.model_extra or {} if "channels" in extra: - channels_config = extra["channels"] - return cls(channels_config=channels_config) + channels_config = dict(extra["channels"] or {}) + _merge_channel_connection_runtime_config(channels_config, app_config) + return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config)) async def start(self) -> None: """Start the manager and all enabled channels.""" @@ -99,36 +134,83 @@ class ChannelService: return await self.manager.start() + self._running = True + ready_status = await self.ensure_ready_channels(attempts=2) + ready_count = sum(1 for ready in ready_status.values() if ready) + logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status)) + + async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]: + """Start or restart 
enabled configured channels that are not ready.""" + ready_status: dict[str, bool] = {} for name, channel_config in self._config.items(): if not isinstance(channel_config, dict): continue if not channel_config.get("enabled", False): - cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, []) - has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys) - if has_creds: + if _channel_has_credentials(name, channel_config): logger.warning( - "Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.", - name, - name, + "A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.", ) else: - logger.info("Channel %s is disabled, skipping", name) + logger.info("A configured channel is disabled, skipping") continue - await self._start_channel(name, channel_config) + ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts) + return ready_status - self._running = True - logger.info("ChannelService started with channels: %s", list(self._channels.keys())) + async def ensure_channel_ready( + self, + name: str, + config: dict[str, Any] | None = None, + *, + attempts: int = 1, + ) -> bool: + """Ensure a single enabled channel is running using its current config.""" + if not self._running: + logger.warning("ChannelService is not running; cannot ensure channel readiness") + return False + + if config is not None: + self._config[name] = dict(config) + + # Serialize per channel: readiness is polled from request handlers, so + # concurrent calls must not stop/start the same channel worker twice. 
+ lock = self._readiness_locks.setdefault(name, asyncio.Lock()) + async with lock: + channel_config = self._config.get(name) + if not channel_config or not isinstance(channel_config, dict): + logger.warning("No config for requested channel") + return False + if not channel_config.get("enabled", False): + return False + + channel = self._channels.get(name) + if channel is not None and channel.is_running: + return True + + if channel is not None: + try: + await channel.stop() + except Exception: + logger.exception("Error stopping non-running channel before readiness retry") + self._channels.pop(name, None) + + max_attempts = max(1, attempts) + for attempt in range(max_attempts): + if attempt > 0: + logger.info("Retrying channel startup after readiness check") + if await self._start_channel(name, channel_config): + return True + return False async def stop(self) -> None: """Stop all channels and the manager.""" for name, channel in list(self._channels.items()): try: await channel.stop() - logger.info("Channel %s stopped", name) + logger.info("Channel stopped") except Exception: - logger.exception("Error stopping channel %s", name) + logger.exception("Error stopping channel") self._channels.clear() await self.manager.stop() @@ -140,6 +222,9 @@ class ChannelService: Uses ``get_app_config()`` which detects file changes via mtime, so edits to ``config.yaml`` are picked up without a process restart. + The UI runtime-config overlay applied at startup is re-applied here + so a file-driven reload neither drops credentials entered from the + browser nor resurrects a channel disconnected from it. Falls back to the cached ``self._config`` when config loading fails. 
""" try: @@ -147,7 +232,8 @@ class ChannelService: app_config = get_app_config() extra = app_config.model_extra or {} - channels_config = extra.get("channels", {}) + channels_config = dict(extra.get("channels") or {}) + _merge_channel_connection_runtime_config(channels_config, app_config) channel_config = channels_config.get(name) if isinstance(channel_config, dict): # Update the cached config so get_status() stays consistent. @@ -157,18 +243,23 @@ class ChannelService: logger.exception("Failed to reload config for channel %s, using cached version", name) return self._config.get(name) - async def restart_channel(self, name: str) -> bool: + async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool: """Restart a specific channel. Returns True if successful.""" if name in self._channels: try: await self._channels[name].stop() except Exception: - logger.exception("Error stopping channel %s for restart", name) + logger.exception("Error stopping channel for restart") del self._channels[name] - config = self._load_channel_config(name) + if reload_config: + # Reading config.yaml and the runtime store is disk IO; keep it + # off the event loop. + config = await asyncio.to_thread(self._load_channel_config, name) + else: + config = self._config.get(name) if not config or not isinstance(config, dict): - logger.warning("No config for channel %s", name) + logger.warning("No config for requested channel") return False if not config.get("enabled", False): @@ -177,11 +268,35 @@ class ChannelService: return await self._start_channel(name, config) + async def configure_channel(self, name: str, config: dict[str, Any]) -> bool: + """Apply runtime config for a channel and restart it if the service is running.""" + self._config[name] = dict(config) + if not self._running: + return True + # The caller just supplied the authoritative config (e.g. 
credentials + # entered in the browser that are never written to config.yaml) — a + # file reload here would clobber it with the stale on-disk entry. + return await self.restart_channel(name, reload_config=False) + + async def remove_channel(self, name: str) -> bool: + """Remove runtime config for a channel and stop it if currently running.""" + self._config.pop(name, None) + channel = self._channels.pop(name, None) + if channel is None: + return True + try: + await channel.stop() + logger.info("Channel stopped and removed") + return True + except Exception: + logger.exception("Error stopping channel for removal") + return False + async def _start_channel(self, name: str, config: dict[str, Any]) -> bool: """Instantiate and start a single channel.""" import_path = _CHANNEL_REGISTRY.get(name) if not import_path: - logger.warning("Unknown channel type: %s", name) + logger.warning("Unknown channel type") return False try: @@ -189,24 +304,26 @@ class ChannelService: channel_cls = resolve_class(import_path, base_class=None) except Exception: - logger.exception("Failed to import channel class for %s", name) + logger.exception("Failed to import channel class") return False try: config = dict(config) config["channel_store"] = self.store + if self._connection_repo is not None: + config["connection_repo"] = self._connection_repo channel = channel_cls(bus=self.bus, config=config) self._channels[name] = channel await channel.start() if not channel.is_running: self._channels.pop(name, None) - logger.error("Channel %s did not enter a running state after start()", name) + logger.error("Channel did not enter a running state after start()") return False - logger.info("Channel %s started", name) + logger.info("Channel started") return True except Exception: self._channels.pop(name, None) - logger.exception("Failed to start channel %s", name) + logger.exception("Failed to start channel") return False def get_status(self) -> dict[str, Any]: @@ -245,7 +362,9 @@ async def 
start_channel_service(app_config: AppConfig | None = None) -> ChannelS global _channel_service if _channel_service is not None: return _channel_service - _channel_service = ChannelService.from_app_config(app_config) + # from_app_config reads the JSON channel store and runtime config files; + # keep that disk IO off the event loop. + _channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config) await _channel_service.start() return _channel_service diff --git a/backend/app/channels/slack.py b/backend/app/channels/slack.py index 3e31a19b2..cfe03c50c 100644 --- a/backend/app/channels/slack.py +++ b/backend/app/channels/slack.py @@ -9,7 +9,8 @@ from typing import Any from markdown_to_mrkdwn import SlackMarkdownConverter from app.channels.base import Channel -from app.channels.commands import is_known_channel_command +from app.channels.commands import extract_connect_code, is_known_channel_command +from app.channels.connection_identity import attach_connection_identity from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -64,6 +65,9 @@ class SlackChannel(Channel): self._web_client = None self._loop: asyncio.AbstractEventLoop | None = None self._allowed_users = _normalize_allowed_users(config.get("allowed_users", [])) + self._connection_repo = config.get("connection_repo") + self._web_client_factory = config.get("web_client_factory") + self._connection_web_clients: dict[str, tuple[str, Any]] = {} configured_bot_user_id = config.get("bot_user_id") self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None @@ -80,26 +84,28 @@ class SlackChannel(Channel): return self._SocketModeResponse = SocketModeResponse + if self._web_client_factory is None: + self._web_client_factory = WebClient bot_token = self.config.get("bot_token", "") app_token = self.config.get("app_token", "") + if self._connection_repo is not None and 
self.config.get("event_delivery") == "http": + if not bot_token: + logger.error("Slack HTTP Events mode requires bot_token") + return + await self._initialize_operator_web_client(str(bot_token)) + self._loop = asyncio.get_event_loop() + self._running = True + self.bus.subscribe_outbound(self._on_outbound) + logger.info("Slack channel started in HTTP Events mode") + return + if not bot_token or not app_token: logger.error("Slack channel requires bot_token and app_token") return - self._web_client = WebClient(token=bot_token) - if self._bot_user_id is None: - try: - auth_info = await asyncio.to_thread(self._web_client.auth_test) - user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None - if user_id is None: - auth_get = getattr(auth_info, "get", None) - user_id = auth_get("user_id") if callable(auth_get) else None - if isinstance(user_id, str) and user_id: - self._bot_user_id = user_id - except Exception: - logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True) + await self._initialize_operator_web_client(str(bot_token)) self._socket_client = SocketModeClient( app_token=app_token, web_client=self._web_client, @@ -124,7 +130,8 @@ class SlackChannel(Channel): logger.info("Slack channel stopped") async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None: - if not self._web_client: + web_client = await self._get_web_client_for_message(msg) + if not web_client: return kwargs: dict[str, Any] = { @@ -137,11 +144,12 @@ class SlackChannel(Channel): last_exc: Exception | None = None for attempt in range(_max_retries): try: - await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs) + await asyncio.to_thread(web_client.chat_postMessage, **kwargs) # Add a completion reaction to the thread root if msg.thread_ts: await asyncio.to_thread( - self._add_reaction, + self._add_reaction_with_client, + web_client, msg.chat_id, msg.thread_ts, "white_check_mark", @@ -165,7 +173,8 
@@ class SlackChannel(Channel): if msg.thread_ts: try: await asyncio.to_thread( - self._add_reaction, + self._add_reaction_with_client, + web_client, msg.chat_id, msg.thread_ts, "x", @@ -177,7 +186,8 @@ class SlackChannel(Channel): raise last_exc async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: - if not self._web_client: + web_client = await self._get_web_client_for_message(msg) + if not web_client: return False try: @@ -190,7 +200,7 @@ class SlackChannel(Channel): if msg.thread_ts: kwargs["thread_ts"] = msg.thread_ts - await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs) + await asyncio.to_thread(web_client.files_upload_v2, **kwargs) logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id) return True except Exception: @@ -199,12 +209,45 @@ class SlackChannel(Channel): # -- internal ---------------------------------------------------------- - def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None: - """Add an emoji reaction to a message (best-effort, non-blocking).""" - if not self._web_client: + async def _initialize_operator_web_client(self, bot_token: str) -> None: + self._web_client = self._web_client_factory(token=bot_token) + if self._bot_user_id is not None: return try: - self._web_client.reactions_add( + auth_info = await asyncio.to_thread(self._web_client.auth_test) + user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None + if user_id is None: + auth_get = getattr(auth_info, "get", None) + user_id = auth_get("user_id") if callable(auth_get) else None + if isinstance(user_id, str) and user_id: + self._bot_user_id = user_id + except Exception: + logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True) + + async def _get_web_client_for_message(self, msg: OutboundMessage): + if msg.connection_id and self._connection_repo is not None: + credentials = await 
self._connection_repo.get_credentials(msg.connection_id) + access_token = credentials.get("access_token") if credentials else None + if not access_token: + return self._web_client + # WebClient keeps its own HTTP session and rate-limit state, so + # reuse one per connection until its token changes. + cached = self._connection_web_clients.get(msg.connection_id) + if cached is not None and cached[0] == access_token: + return cached[1] + if self._web_client_factory is None: + from slack_sdk import WebClient + + self._web_client_factory = WebClient + web_client = self._web_client_factory(token=access_token) + self._connection_web_clients[msg.connection_id] = (access_token, web_client) + return web_client + return self._web_client + + @staticmethod + def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None: + try: + web_client.reactions_add( channel=channel_id, timestamp=timestamp, name=emoji, @@ -213,6 +256,12 @@ class SlackChannel(Channel): if "already_reacted" not in str(exc): logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc) + def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None: + """Add an emoji reaction to a message (best-effort, non-blocking).""" + if not self._web_client: + return + self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji) + def _send_running_reply(self, channel_id: str, thread_ts: str) -> None: """Send a 'Working on it......' 
reply in the thread (called from SDK thread).""" if not self._web_client: @@ -249,12 +298,15 @@ class SlackChannel(Channel): # Handle message events (DM or @mention) if etype in ("message", "app_mention"): - self._handle_message_event(event) + self._handle_message_event( + event, + team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"), + ) except Exception: logger.exception("Error processing Slack event") - def _handle_message_event(self, event: dict) -> None: + def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None: # Ignore bot messages if event.get("bot_id") or event.get("subtype"): return @@ -272,6 +324,19 @@ class SlackChannel(Channel): if not text: return + connect_code = extract_connect_code(text) + if connect_code: + if self._loop and self._loop.is_running(): + asyncio.run_coroutine_threadsafe( + self._bind_connection_from_connect_code( + event=event, + team_id=str(team_id or event.get("team") or ""), + code=connect_code, + ), + self._loop, + ) + return + channel_id = event.get("channel", "") thread_ts = event.get("thread_ts") or event.get("ts", "") @@ -297,4 +362,61 @@ class SlackChannel(Channel): self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes") # Send "running" reply first (fire-and-forget from SDK thread) self._send_running_reply(channel_id, thread_ts) - asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop) + if self._connection_repo is None: + asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop) + else: + asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop) + + async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None: + inbound = await self._attach_connection_identity(inbound, team_id=team_id) + await self.bus.publish_inbound(inbound) + + async def _attach_connection_identity(self, inbound, *, team_id: str | None = None): + 
workspace_id = str(team_id or inbound.metadata.get("team_id") or "") + return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="slack", + workspace_id=workspace_id, + ) + + async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool: + if self._connection_repo is None or not code: + return False + + channel_id = str(event.get("channel") or "") + thread_ts = str(event.get("thread_ts") or event.get("ts") or "") + state = await self._connection_repo.consume_oauth_state(provider="slack", state=code) + if state is None: + await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts) + return True + + user_id = str(event.get("user") or "") + if not user_id or not team_id: + await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts) + return True + + await self._connection_repo.upsert_connection( + owner_user_id=state["owner_user_id"], + provider="slack", + external_account_id=user_id, + workspace_id=team_id, + metadata={ + "team_id": team_id, + "channel_id": channel_id, + }, + status="connected", + ) + await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts) + return True + + async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None: + if not self._web_client or not channel_id: + return + kwargs: dict[str, Any] = {"channel": channel_id, "text": text} + if thread_ts: + kwargs["thread_ts"] = thread_ts + try: + await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs) + except Exception: + logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id) diff --git a/backend/app/channels/telegram.py b/backend/app/channels/telegram.py index fabdbfb61..cafcb6692 100644 --- a/backend/app/channels/telegram.py +++ b/backend/app/channels/telegram.py @@ -8,6 +8,7 @@ import threading from typing 
import Any from app.channels.base import Channel +from app.channels.connection_identity import attach_connection_identity from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ class TelegramChannel(Channel): pass # chat_id -> last sent message_id for threaded replies self._last_bot_message: dict[str, int] = {} + self._connection_repo = config.get("connection_repo") async def start(self) -> None: if self._running: @@ -233,6 +235,54 @@ class TelegramChannel(Channel): return True return user_id in self._allowed_users + @staticmethod + def _telegram_display_name(user) -> str: + full_name = getattr(user, "full_name", None) + if isinstance(full_name, str) and full_name: + return full_name + username = getattr(user, "username", None) + if isinstance(username, str) and username: + return username + return str(getattr(user, "id", "")) + + async def _bind_connection_from_start_token(self, update, state_token: str) -> bool: + if self._connection_repo is None or not state_token: + return False + + state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token) + if state is None: + await update.message.reply_text("Telegram connection link is invalid or expired.") + return True + + owner_user_id = state["owner_user_id"] + user_id = str(update.effective_user.id) + chat_id = str(update.effective_chat.id) + connection = await self._connection_repo.upsert_connection( + owner_user_id=owner_user_id, + provider="telegram", + external_account_id=user_id, + external_account_name=self._telegram_display_name(update.effective_user), + workspace_id=chat_id, + workspace_name=None, + metadata={ + "chat_id": chat_id, + "chat_type": update.effective_chat.type, + "telegram_username": getattr(update.effective_user, "username", None), + }, + status="connected", + ) + logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s 
connection=%s", chat_id, user_id, owner_user_id, connection["id"]) + await update.message.reply_text("Telegram connected to DeerFlow.") + return True + + async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage: + return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="telegram", + workspace_id=inbound.chat_id, + ) + def _get_bot_username(self, context) -> str | None: bot = getattr(context, "bot", None) username = getattr(bot, "username", None) @@ -264,6 +314,11 @@ class TelegramChannel(Channel): """Handle /start command.""" if not self._check_user(update.effective_user.id): return + args = getattr(context, "args", []) if context is not None else [] + if args: + handled = await self._bind_connection_from_start_token(update, str(args[0])) + if handled: + return await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.") async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None: @@ -299,6 +354,7 @@ class TelegramChannel(Channel): thread_ts=msg_id, ) inbound.topic_id = topic_id + inbound = await self._attach_connection_identity(inbound) if self._main_loop and self._main_loop.is_running(): fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop) @@ -341,6 +397,7 @@ class TelegramChannel(Channel): thread_ts=msg_id, ) inbound.topic_id = topic_id + inbound = await self._attach_connection_identity(inbound) if self._main_loop and self._main_loop.is_running(): fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop) diff --git a/backend/app/channels/wechat.py b/backend/app/channels/wechat.py index 9a9ddf1a6..a605a8d2f 100644 --- a/backend/app/channels/wechat.py +++ b/backend/app/channels/wechat.py @@ -22,8 +22,9 @@ from 
cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from app.channels.base import Channel -from app.channels.commands import is_known_channel_command -from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment +from app.channels.commands import extract_connect_code, is_known_channel_command +from app.channels.connection_identity import attach_connection_identity +from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment logger = logging.getLogger(__name__) @@ -253,6 +254,7 @@ class WechatChannel(Channel): self._state_dir = self._resolve_state_dir(config.get("state_dir")) self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None + self._connection_repo = config.get("connection_repo") self._load_state() async def start(self) -> None: @@ -617,6 +619,16 @@ class WechatChannel(Channel): if thread_ts: self._context_tokens_by_thread[thread_ts] = context_token + connect_code = extract_connect_code(text) + if connect_code and self._connection_repo is not None: + handled = await self._bind_connection_from_connect_code( + chat_id=chat_id, + context_token=context_token, + code=connect_code, + ) + if handled: + return + inbound = self._make_inbound( chat_id=chat_id, user_id=chat_id, @@ -632,8 +644,54 @@ class WechatChannel(Channel): }, ) inbound.topic_id = None + inbound = await self._attach_connection_identity(inbound) await self.bus.publish_inbound(inbound) + async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage: + return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="wechat", + workspace_id=inbound.chat_id, + ) + + async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: 
str) -> bool: + if self._connection_repo is None or not code: + return False + + state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code) + if state is None: + await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.") + return True + + if not chat_id: + await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.") + return True + + await self._connection_repo.upsert_connection( + owner_user_id=state["owner_user_id"], + provider="wechat", + external_account_id=chat_id, + workspace_id=chat_id, + metadata={ + "context_token": context_token, + }, + status="connected", + ) + await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.") + return True + + async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None: + if not context_token: + return + await self._send_text_message( + chat_id=chat_id, + context_token=context_token, + text=text, + client_id_prefix="deerflow-connect", + max_retries=1, + ) + async def _ensure_authenticated(self) -> bool: async with self._auth_lock: if self._bot_token: diff --git a/backend/app/channels/wecom.py b/backend/app/channels/wecom.py index 33d3cf1bb..badb0b525 100644 --- a/backend/app/channels/wecom.py +++ b/backend/app/channels/wecom.py @@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable from typing import Any, cast from app.channels.base import Channel -from app.channels.commands import is_known_channel_command +from app.channels.commands import extract_connect_code, is_known_channel_command +from app.channels.connection_identity import attach_connection_identity from app.channels.message_bus import ( + InboundMessage, InboundMessageType, MessageBus, OutboundMessage, @@ -29,6 +31,7 @@ class WeComChannel(Channel): self._ws_frames: dict[str, dict[str, Any]] = {} self._ws_stream_ids: dict[str, str] = {} self._working_message = 
"Working on it..." + self._connection_repo = config.get("connection_repo") @property def supports_streaming(self) -> bool: @@ -271,6 +274,16 @@ class WeComChannel(Channel): user_id = (body.get("from") or {}).get("userid") + connect_code = extract_connect_code(text) + if connect_code and self._connection_repo is not None: + handled = await self._bind_connection_from_connect_code( + frame=frame, + user_id=str(user_id or ""), + code=connect_code, + ) + if handled: + return + inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT inbound = self._make_inbound( chat_id=user_id, # keep user's conversation in memory @@ -292,8 +305,52 @@ class WeComChannel(Channel): except Exception: pass + inbound = await self._attach_connection_identity(inbound) await self.bus.publish_inbound(inbound) + async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage: + return await attach_connection_identity( + inbound, + repo=self._connection_repo, + provider="wecom", + workspace_id=str(inbound.metadata.get("aibotid") or "") or None, + fallback_without_workspace=True, + ) + + async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool: + if self._connection_repo is None or not code: + return False + + state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code) + if state is None: + await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.") + return True + + if not user_id: + await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.") + return True + + body = frame.get("body", {}) or {} + workspace_id = str(body.get("aibotid") or "") or None + await self._connection_repo.upsert_connection( + owner_user_id=state["owner_user_id"], + provider="wecom", + external_account_id=user_id, + workspace_id=workspace_id, + metadata={ + "aibotid": workspace_id, + "chattype": 
body.get("chattype"), + }, + status="connected", + ) + await self._send_connection_reply(frame, "WeCom connected to DeerFlow.") + return True + + async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None: + if not self._ws_client: + return + await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}}) + async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None: if not self._ws_client: return diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index e0aebc180..7e080a587 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -16,6 +16,7 @@ from app.gateway.routers import ( artifacts, assistants_compat, auth, + channel_connections, channels, feedback, mcp, @@ -384,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for # Suggestions API is mounted at /api/threads/{thread_id}/suggestions app.include_router(suggestions.router) + # User-facing IM channel connection API is mounted at /api/channels + app.include_router(channel_connections.router) + # Channels API is mounted at /api/channels app.include_router(channels.router) diff --git a/backend/app/gateway/auth_disabled.py b/backend/app/gateway/auth_disabled.py index 396de7129..ef8e6e78c 100644 --- a/backend/app/gateway/auth_disabled.py +++ b/backend/app/gateway/auth_disabled.py @@ -6,9 +6,11 @@ import logging import os from types import SimpleNamespace +from deerflow.runtime.user_context import DEFAULT_USER_ID + AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED" -AUTH_DISABLED_USER_ID = "e2e-user" -AUTH_DISABLED_USER_EMAIL = "e2e@test.local" +AUTH_DISABLED_USER_ID = DEFAULT_USER_ID +AUTH_DISABLED_USER_EMAIL = "default@test.local" AUTH_SOURCE_SESSION = "session" AUTH_SOURCE_INTERNAL = "internal" diff --git a/backend/app/gateway/authz.py b/backend/app/gateway/authz.py index c7cf63858..aa82df076 100644 --- a/backend/app/gateway/authz.py +++ b/backend/app/gateway/authz.py @@ -276,6 
+276,8 @@ def require_permission( # strict-deny rather than strict-allow — only an *existing* # row with a *different* user_id triggers 404. if owner_check: + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE + thread_id = kwargs.get("thread_id") if thread_id is None: raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") @@ -288,6 +290,22 @@ def require_permission( str(auth.user.id), require_existing=require_existing, ) + if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE: + # Trusted internal callers (channel workers) also act for + # the connection owner carried in X-DeerFlow-Owner-User-Id. + # Scope the check to that owner instead of bypassing it; a + # leaked internal token must not grant cross-user thread + # access. The header is honored only after ``auth`` proved + # the caller holds the internal token (mirrors + # get_trusted_internal_owner_user_id, which keys off the + # middleware-stamped ``request.state.user``). 
+ header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip() + if header_owner: + allowed = await thread_store.check_access( + thread_id, + header_owner, + require_existing=require_existing, + ) if not allowed: raise HTTPException( status_code=404, diff --git a/backend/app/gateway/internal_auth.py b/backend/app/gateway/internal_auth.py index 3a00a9662..400e997bb 100644 --- a/backend/app/gateway/internal_auth.py +++ b/backend/app/gateway/internal_auth.py @@ -5,10 +5,12 @@ from __future__ import annotations import os import secrets from types import SimpleNamespace +from typing import Any from deerflow.runtime.user_context import DEFAULT_USER_ID INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" +INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" INTERNAL_SYSTEM_ROLE = "internal" @@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str: _INTERNAL_AUTH_TOKEN = _load_internal_auth_token() -def create_internal_auth_headers() -> dict[str, str]: +def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]: """Return headers that authenticate trusted Gateway internal calls.""" - return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} + headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} + if owner_user_id: + headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id + return headers def is_valid_internal_auth_token(token: str | None) -> bool: @@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool: def get_internal_user(): """Return the synthetic user used for trusted internal channel calls.""" return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) + + +def get_trusted_internal_owner_user_id(request: Any) -> str | None: + """Return the owner override for a trusted internal request, if present. + + The header is ignored for normal browser/API callers. 
It is only honored + after ``AuthMiddleware`` has validated the internal auth token and stamped + the synthetic internal user onto ``request.state.user``. + """ + user = getattr(getattr(request, "state", None), "user", None) + if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE: + return None + + owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) + if not owner_user_id: + return None + owner_user_id = owner_user_id.strip() + return owner_user_id or None diff --git a/backend/app/gateway/routers/channel_connections.py b/backend/app/gateway/routers/channel_connections.py new file mode 100644 index 000000000..1c7133078 --- /dev/null +++ b/backend/app/gateway/routers/channel_connections.py @@ -0,0 +1,670 @@ +"""Browser-facing APIs for user-owned IM channel bindings.""" + +from __future__ import annotations + +import asyncio +import logging +import secrets +from datetime import UTC, datetime, timedelta +from typing import Any + +from fastapi import APIRouter, HTTPException, Request, Response +from pydantic import BaseModel, Field + +from app.channels.runtime_config_store import ( + ChannelRuntimeConfigStore, + apply_runtime_connection_config, + merge_runtime_channel_configs, +) +from deerflow.config.channel_connections_config import ChannelConnectionsConfig +from deerflow.persistence.channel_connections import ChannelConnectionRepository +from deerflow.persistence.engine import get_session_factory + +router = APIRouter(prefix="/api/channels", tags=["channel-connections"]) +logger = logging.getLogger(__name__) + +_STATE_TTL_SECONDS = 600 +_MASKED_CREDENTIAL_VALUE = "********" + + +class ChannelCredentialFieldResponse(BaseModel): + name: str + label: str + type: str = "text" + required: bool = True + + +class ChannelProviderResponse(BaseModel): + provider: str + display_name: str + enabled: bool + configured: bool + connectable: bool + unavailable_reason: str | None = None + auth_mode: str + connection_status: str + credential_fields: 
list[ChannelCredentialFieldResponse] = Field(default_factory=list) + credential_values: dict[str, str] = Field(default_factory=dict) + + +class ChannelProvidersResponse(BaseModel): + enabled: bool + providers: list[ChannelProviderResponse] + + +class ChannelConnectionResponse(BaseModel): + id: str + provider: str + status: str + external_account_id: str | None = None + external_account_name: str | None = None + workspace_id: str | None = None + workspace_name: str | None = None + scopes: list[str] = Field(default_factory=list) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ChannelConnectionsResponse(BaseModel): + connections: list[ChannelConnectionResponse] + + +class ChannelConnectResponse(BaseModel): + provider: str + mode: str + url: str | None = None + code: str + instruction: str + expires_in: int + + +class ChannelRuntimeConfigRequest(BaseModel): + values: dict[str, str] = Field(default_factory=dict) + + +_PROVIDER_META: dict[str, dict[str, str]] = { + "telegram": {"display_name": "Telegram", "auth_mode": "deep_link"}, + "slack": {"display_name": "Slack", "auth_mode": "binding_code"}, + "discord": {"display_name": "Discord", "auth_mode": "binding_code"}, + "feishu": {"display_name": "Feishu", "auth_mode": "binding_code"}, + "dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"}, + "wechat": {"display_name": "WeChat", "auth_mode": "binding_code"}, + "wecom": {"display_name": "WeCom", "auth_mode": "binding_code"}, +} + +_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = { + "telegram": ( + {"name": "bot_token", "label": "Bot token", "type": "password"}, + {"name": "bot_username", "label": "Bot username", "type": "text"}, + ), + "slack": ( + {"name": "bot_token", "label": "Bot token", "type": "password"}, + {"name": "app_token", "label": "App token", "type": "password"}, + ), + "discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},), + "feishu": ( + {"name": "app_id", "label": "App ID", "type": 
"text"}, + {"name": "app_secret", "label": "App secret", "type": "password"}, + ), + "dingtalk": ( + {"name": "client_id", "label": "Client ID", "type": "text"}, + {"name": "client_secret", "label": "Client secret", "type": "password"}, + ), + "wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},), + "wecom": ( + {"name": "bot_id", "label": "Bot ID", "type": "text"}, + {"name": "bot_secret", "label": "Bot secret", "type": "password"}, + ), +} + +_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = { + "telegram": ("bot_token",), + "slack": ("bot_token", "app_token"), + "discord": ("bot_token",), + "feishu": ("app_id", "app_secret"), + "dingtalk": ("client_id", "client_secret"), + "wechat": ("bot_token",), + "wecom": ("bot_id", "bot_secret"), +} + + +def _get_user_id(request: Request) -> str: + user = getattr(request.state, "user", None) + if user is None: + raise HTTPException(status_code=401, detail="Authentication required") + return str(user.id) + + +async def _require_admin_user(request: Request) -> None: + """Require an admin caller for instance-wide channel runtime mutations. + + Runtime credentials and the channel workers they start/stop are shared by + every user of the deployment, so only admins may change them (same model + as the MCP config API). Auth-disabled local mode uses a synthetic admin + user and is unaffected. 
+ """ + user = getattr(request.state, "user", None) + if user is None: + from app.gateway.deps import get_current_user_from_request + + user = await get_current_user_from_request(request) + + if getattr(user, "system_role", None) != "admin": + raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.") + + +def _get_app_config(): + from deerflow.config.app_config import get_app_config + + return get_app_config() + + +async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore: + store = getattr(request.app.state, "channel_runtime_config_store", None) + if isinstance(store, ChannelRuntimeConfigStore): + return store + # Constructing the store reads its JSON file from disk; keep it off the + # event loop. + store = await asyncio.to_thread(ChannelRuntimeConfigStore) + request.app.state.channel_runtime_config_store = store + return store + + +async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig: + config = getattr(request.app.state, "channel_connections_config", None) + if not isinstance(config, ChannelConnectionsConfig): + config = _get_app_config().channel_connections + config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request)) + request.app.state.channel_connections_config = config + return config + + +async def _get_channels_config(request: Request) -> dict[str, Any]: + state_config = getattr(request.app.state, "channels_config", None) + if isinstance(state_config, dict): + return state_config + + result = await _load_channels_config(request, await _get_channel_connections_config(request)) + request.app.state.channels_config = result + return result + + +async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]: + app_config = _get_app_config() + extra = app_config.model_extra or {} + channels_config = extra.get("channels") + result = dict(channels_config) if 
isinstance(channels_config, dict) else {} + merge_runtime_channel_configs( + result, + config, + store=await _get_runtime_config_store(request), + ) + return result + + +def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository: + repo = getattr(request.app.state, "channel_connection_repo", None) + if isinstance(repo, ChannelConnectionRepository): + return repo + + sf = get_session_factory() + if sf is None: + raise HTTPException(status_code=503, detail="Channel connection persistence is not available") + + repo = ChannelConnectionRepository(sf) + request.app.state.channel_connection_repo = repo + return repo + + +def _provider_config(config: ChannelConnectionsConfig, provider: str): + provider_config = getattr(config, provider, None) + if provider_config is None: + raise HTTPException(status_code=404, detail="Unknown channel provider") + return provider_config + + +def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool: + runtime_config = channels_config.get(provider) + if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False): + return False + return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider]) + + +def _runtime_unavailable_reason(provider: str) -> str: + meta = _PROVIDER_META.get(provider) + display_name = meta["display_name"] if meta else provider + return f"Enter the required {display_name} credentials to connect this channel." + + +def _runtime_not_running_reason(provider: str) -> str: + meta = _PROVIDER_META.get(provider) + display_name = meta["display_name"] if meta else provider + return f"{display_name} channel is configured but is not running. Check the credentials and service logs." 
+ + +def _runtime_channel_running(provider: str) -> bool | None: + try: + from app.channels.service import get_channel_service + except Exception: + logger.debug("Unable to inspect channel service status", exc_info=True) + return None + + service = get_channel_service() + if service is None: + return None + try: + status = service.get_status() + except Exception: + logger.debug("Unable to read channel service status", exc_info=True) + return None + + if not status.get("service_running"): + return False + channel_status = status.get("channels", {}).get(provider) + if not isinstance(channel_status, dict): + return None + return bool(channel_status.get("running")) + + +async def _ensure_runtime_channel_ready_if_available( + provider: str, + channels_config: dict[str, Any], +) -> bool | None: + runtime_config = channels_config.get(provider) + if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False): + return None + + try: + from app.channels.service import get_channel_service + except Exception: + logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True) + return None + + service = get_channel_service() + if service is None: + return None + + ensure_channel_ready = getattr(service, "ensure_channel_ready", None) + if ensure_channel_ready is None: + return None + + try: + return await ensure_channel_ready(provider, runtime_config) + except Exception: + logger.exception("Failed to reconcile runtime channel readiness") + return False + + +def _provider_unavailable_reason( + config: ChannelConnectionsConfig, + channels_config: dict[str, Any], + provider: str, +) -> str | None: + provider_config = _provider_config(config, provider) + if not provider_config.enabled: + return None + if not provider_config.configured: + return _runtime_unavailable_reason(provider) + if not _runtime_channel_configured(provider, channels_config): + return _runtime_unavailable_reason(provider) + if _runtime_channel_running(provider) is 
False: + return _runtime_not_running_reason(provider) + return None + + +def _provider_status( + config: ChannelConnectionsConfig, + channels_config: dict[str, Any], + provider: str, +) -> tuple[dict[str, bool], str | None]: + declared = config.provider_status(provider) + unavailable_reason = _provider_unavailable_reason(config, channels_config, provider) + configured = declared["configured"] and _runtime_channel_configured(provider, channels_config) + return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason + + +def _new_binding_code() -> str: + return secrets.token_urlsafe(16) + + +async def _create_state( + repo: ChannelConnectionRepository, + *, + owner_user_id: str, + provider: str, +) -> str: + state = _new_binding_code() + await repo.create_oauth_state( + owner_user_id=owner_user_id, + provider=provider, + state=state, + expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS), + ) + return state + + +def _connect_instruction(provider: str, code: str) -> str: + if provider == "telegram": + return f"Send /start {code} to the DeerFlow Telegram bot." + meta = _PROVIDER_META.get(provider) + if meta is None: + raise HTTPException(status_code=404, detail="Unknown channel provider") + return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot." 
+ + +def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None: + if provider == "telegram": + provider_config = _provider_config(config, provider) + return f"https://t.me/{provider_config.bot_username}?start={code}" + if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code": + return None + raise HTTPException(status_code=404, detail="Unknown channel provider") + + +def _connection_updated_at(connection: dict[str, Any]) -> datetime: + value = connection.get("updated_at") + if isinstance(value, datetime): + return value if value.tzinfo is not None else value.replace(tzinfo=UTC) + if isinstance(value, str) and value: + try: + return datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError: + pass + return datetime.min.replace(tzinfo=UTC) + + +def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]: + by_provider: dict[str, dict[str, Any]] = {} + for item in connections: + existing = by_provider.get(item["provider"]) + if existing is None or _connection_updated_at(item) > _connection_updated_at(existing): + by_provider[item["provider"]] = item + return by_provider + + +def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]: + fields = _CREDENTIAL_FIELDS.get(provider) + if fields is None: + raise HTTPException(status_code=404, detail="Unknown channel provider") + return [ChannelCredentialFieldResponse(**field) for field in fields] + + +def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]: + runtime_config = channels_config.get(provider) + if not isinstance(runtime_config, dict): + return {} + + values: dict[str, str] = {} + for field in _credential_fields(provider): + value = str(runtime_config.get(field.name) or "").strip() + if not value: + continue + values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value + return values + + +def _provider_response( + config: 
ChannelConnectionsConfig, + channels_config: dict[str, Any], + provider: str, + meta: dict[str, str], + connection: dict[str, Any] | None = None, +) -> ChannelProviderResponse: + from app.gateway.auth_disabled import is_auth_disabled + + status, unavailable_reason = _provider_status(config, channels_config, provider) + if connection: + connection_status = connection["status"] + elif is_auth_disabled() and status["configured"] and unavailable_reason is None: + # Auth-disabled local mode routes every channel message to the default + # user, so a configured running channel needs no per-user binding. + connection_status = "connected" + else: + connection_status = "not_connected" + credential_values = _credential_values(provider, channels_config) + if provider == "telegram" and not credential_values.get("bot_username"): + bot_username = str(_provider_config(config, provider).bot_username or "").strip() + if bot_username: + credential_values["bot_username"] = bot_username + return ChannelProviderResponse( + provider=provider, + display_name=meta["display_name"], + enabled=status["enabled"], + configured=status["configured"], + connectable=status["enabled"] and status["configured"] and unavailable_reason is None, + unavailable_reason=unavailable_reason, + auth_mode=meta["auth_mode"], + connection_status=connection_status, + credential_fields=_credential_fields(provider), + credential_values=credential_values, + ) + + +def _required_runtime_values( + provider: str, + values: dict[str, str], + existing_config: dict[str, Any] | None = None, +) -> dict[str, str]: + fields = _credential_fields(provider) + cleaned: dict[str, str] = {} + missing: list[str] = [] + existing_config = existing_config or {} + for field in fields: + raw_value = values.get(field.name, "") + if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE: + existing_value = str(existing_config.get(field.name) or "").strip() + if existing_value: + cleaned[field.name] = existing_value + continue + 
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip() + if field.required and not value: + missing.append(field.label) + cleaned[field.name] = value + if missing: + raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}") + return cleaned + + +async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None: + try: + from app.channels.service import get_channel_service + except Exception: + logger.exception("Failed to import channel service while configuring a runtime channel") + return None + + service = get_channel_service() + if service is None: + return None + return await service.configure_channel(provider, runtime_config) + + +async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None: + try: + from app.channels.service import get_channel_service + except Exception: + logger.exception("Failed to import channel service while disconnecting a runtime channel") + return None + + service = get_channel_service() + if service is None: + return None + + runtime_config = channels_config.get(provider) + if isinstance(runtime_config, dict) and runtime_config.get("enabled", False): + return await service.configure_channel(provider, runtime_config) + return await service.remove_channel(provider) + + +@router.get("/providers", response_model=ChannelProvidersResponse) +async def get_channel_providers(request: Request) -> ChannelProvidersResponse: + config = await _get_channel_connections_config(request) + channels_config = await _get_channels_config(request) + repo = None + if config.enabled: + try: + repo = _get_repository(request, config) + except HTTPException as exc: + if exc.status_code != 503: + raise + owner_user_id = _get_user_id(request) + connections = await repo.list_connections(owner_user_id) if repo is not None else [] + by_provider = _newest_connection_by_provider(connections) + + 
enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]] + # Readiness reconciliation is independent per provider; run it + # concurrently so one slow channel restart does not serialize the + # whole /providers response. + await asyncio.gather( + *(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)), + ) + + providers: list[ChannelProviderResponse] = [] + for provider in enabled_providers: + connection = by_provider.get(provider) + providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection)) + return ChannelProvidersResponse(enabled=config.enabled, providers=providers) + + +@router.get("/connections", response_model=ChannelConnectionsResponse) +async def get_channel_connections(request: Request) -> ChannelConnectionsResponse: + config = await _get_channel_connections_config(request) + if not config.enabled: + return ChannelConnectionsResponse(connections=[]) + repo = _get_repository(request, config) + rows = await repo.list_connections(_get_user_id(request)) + return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows]) + + +@router.delete("/connections/{connection_id}", status_code=204) +async def disconnect_channel_connection(connection_id: str, request: Request) -> Response: + config = await _get_channel_connections_config(request) + if not config.enabled: + raise HTTPException(status_code=400, detail="Channel connections are disabled") + + repo = _get_repository(request, config) + disconnected = await repo.disconnect_connection( + connection_id=connection_id, + owner_user_id=_get_user_id(request), + ) + if not disconnected: + raise HTTPException(status_code=404, detail="Channel connection not found") + return Response(status_code=204) + + +@router.delete("/{provider}/runtime-config", 
response_model=ChannelProviderResponse) +async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse: + await _require_admin_user(request) + config = await _get_channel_connections_config(request) + if not config.enabled: + raise HTTPException(status_code=400, detail="Channel connections are disabled") + + provider_config = _provider_config(config, provider) + if not provider_config.enabled: + raise HTTPException(status_code=400, detail="Channel provider is not enabled") + + owner_user_id = _get_user_id(request) + try: + repo = _get_repository(request, config) + except HTTPException as exc: + if exc.status_code != 503: + raise + repo = None + + if repo is not None: + for connection in await repo.list_connections(owner_user_id): + if connection["provider"] == provider and connection["status"] != "revoked": + await repo.disconnect_connection( + connection_id=connection["id"], + owner_user_id=owner_user_id, + ) + + store = await _get_runtime_config_store(request) + await asyncio.to_thread(store.set_provider_disconnected, provider) + channels_config = await _load_channels_config(request, config) + request.app.state.channels_config = channels_config + + stopped = await _sync_runtime_channel_after_removal(provider, channels_config) + if stopped is False: + display_name = _PROVIDER_META[provider]["display_name"] + raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. 
Try again.") + + return _provider_response(config, channels_config, provider, _PROVIDER_META[provider]) + + +@router.post("/{provider}/connect", response_model=ChannelConnectResponse) +async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse: + config = await _get_channel_connections_config(request) + channels_config = await _get_channels_config(request) + if not config.enabled: + raise HTTPException(status_code=400, detail="Channel connections are disabled") + + provider_config = _provider_config(config, provider) + if provider_config.enabled and _runtime_channel_configured(provider, channels_config): + await _ensure_runtime_channel_ready_if_available(provider, channels_config) + + status, unavailable_reason = _provider_status(config, channels_config, provider) + if not status["enabled"]: + raise HTTPException(status_code=400, detail="Channel provider is not enabled") + if unavailable_reason: + raise HTTPException(status_code=400, detail=unavailable_reason) + if not status["configured"]: + raise HTTPException(status_code=400, detail="Channel provider is not configured") + + repo = _get_repository(request, config) + code = await _create_state( + repo, + owner_user_id=_get_user_id(request), + provider=provider, + ) + return ChannelConnectResponse( + provider=provider, + mode=_PROVIDER_META[provider]["auth_mode"], + url=_connect_url(config, provider, code), + code=code, + instruction=_connect_instruction(provider, code), + expires_in=_STATE_TTL_SECONDS, + ) + + +@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse) +async def configure_channel_provider_runtime( + provider: str, + body: ChannelRuntimeConfigRequest, + request: Request, +) -> ChannelProviderResponse: + await _require_admin_user(request) + config = await _get_channel_connections_config(request) + if not config.enabled: + raise HTTPException(status_code=400, detail="Channel connections are disabled") + + provider_config = 
_provider_config(config, provider) + if not provider_config.enabled: + raise HTTPException(status_code=400, detail="Channel provider is not enabled") + + channels_config = await _get_channels_config(request) + existing = channels_config.get(provider) + runtime_config = dict(existing) if isinstance(existing, dict) else {} + values = _required_runtime_values(provider, body.values, runtime_config) + runtime_config["enabled"] = True + + for key in _RUNTIME_REQUIREMENTS[provider]: + runtime_config[key] = values[key] + + if provider == "telegram": + # The deep-link username is persisted with the runtime channel config + # (set_provider_config below) and applied to future requests via + # apply_runtime_connection_config; never mutate the config instance + # cached by get_app_config(). + runtime_config["bot_username"] = values["bot_username"] + + channels_config[provider] = runtime_config + request.app.state.channels_config = channels_config + + started = await _restart_runtime_channel_if_available(provider, runtime_config) + if started is False: + display_name = _PROVIDER_META[provider]["display_name"] + raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. 
Check the values and try again.") + + store = await _get_runtime_config_store(request) + await asyncio.to_thread(store.set_provider_config, provider, runtime_config) + + return _provider_response(config, channels_config, provider, _PROVIDER_META[provider]) diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index fa8de61ff..fd6c05289 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer +from app.gateway.internal_auth import get_trusted_internal_owner_user_id from app.gateway.utils import sanitize_log_param from deerflow.config.paths import Paths, get_paths from deerflow.runtime import serialize_channel_values @@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe thread_store = get_thread_store(request) thread_id = body.thread_id or str(uuid.uuid4()) now = now_iso() + thread_owner_user_id = get_trusted_internal_owner_user_id(request) + thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {} # ``body.metadata`` is already stripped of server-reserved keys by # ``ThreadCreateRequest._strip_reserved`` — see the model definition. 
# Idempotency: return existing record when already present - existing_record = await thread_store.get(thread_id) + existing_record = await thread_store.get(thread_id, **thread_owner_kwargs) + if existing_record is None and thread_owner_user_id: + unscoped_record = await thread_store.get(thread_id, user_id=None) + if unscoped_record is not None: + if unscoped_record.get("user_id") != thread_owner_user_id: + await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None) + existing_record = await thread_store.get(thread_id, **thread_owner_kwargs) if existing_record is not None: return ThreadResponse( thread_id=thread_id, @@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe await thread_store.create( thread_id, assistant_id=getattr(body, "assistant_id", None), + **thread_owner_kwargs, metadata=body.metadata, ) except Exception: diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index e9b5e212a..2f22fd731 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -12,6 +12,7 @@ import json import logging import re from collections.abc import Mapping +from types import SimpleNamespace from typing import Any from fastapi import HTTPException, Request @@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage from langchain_core.messages.utils import convert_to_messages from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge -from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE +from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id from app.gateway.utils import sanitize_log_param from deerflow.config.app_config import get_app_config from deerflow.runtime import ( @@ -35,6 +36,7 @@ from deerflow.runtime import ( run_agent, ) from deerflow.runtime.runs.naming import resolve_root_run_name +from deerflow.runtime.user_context import reset_current_user, set_current_user logger = 
logging.getLogger(__name__) @@ -315,6 +317,7 @@ async def start_run( detail=f"Model {model_name!r} is not in the configured model allowlist", ) + owner_user_id = get_trusted_internal_owner_user_id(request) # Stateless run endpoints carry thread_id in the request *body*, so the # @require_permission(owner_check=True) decorator -- which resolves ownership # from the path param -- cannot protect them. Enforce thread ownership here, @@ -323,79 +326,99 @@ async def start_run( # temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible # via check_access; only a thread already owned by another user is rejected # with 404, matching thread_runs.py's anti-enumeration behaviour. Internal - # channel runs act on behalf of IM users they do not own (see - # inject_authenticated_user_context), so the internal system role is exempt. + # channel runs act on behalf of the connection owner carried in + # X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of + # bypassing the check -- a leaked internal token must not grant cross-user + # thread access. user = getattr(request.state, "user", None) - if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE: - if not await run_ctx.thread_store.check_access(thread_id, str(user.id)): + if user is not None: + allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id)) + if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE: + # Channel workers may also act for the connection owner named in + # the trusted header (e.g. claiming a legacy default-owned channel + # thread for its real owner). 
+ allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id) + if not allowed: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None try: - record = await run_mgr.create_or_reject( - thread_id, - body.assistant_id, - on_disconnect=disconnect, - metadata=body.metadata or {}, - kwargs={"input": body.input, "config": body.config}, - multitask_strategy=body.multitask_strategy, - model_name=model_name, - ) - except ConflictError as exc: - raise HTTPException(status_code=409, detail=str(exc)) from exc - except UnsupportedStrategyError as exc: - raise HTTPException(status_code=501, detail=str(exc)) from exc - - # Upsert thread metadata so the thread appears in /threads/search, - # even for threads that were never explicitly created via POST /threads - # (e.g. stateless runs). - try: - existing = await run_ctx.thread_store.get(thread_id) - if existing is None: - await run_ctx.thread_store.create( + try: + record = await run_mgr.create_or_reject( thread_id, - assistant_id=body.assistant_id, - metadata=body.metadata, + body.assistant_id, + on_disconnect=disconnect, + metadata=body.metadata or {}, + kwargs={"input": body.input, "config": body.config}, + multitask_strategy=body.multitask_strategy, + model_name=model_name, + user_id=owner_user_id, ) - else: - await run_ctx.thread_store.update_status(thread_id, "running") - except Exception: - logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) + except ConflictError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + except UnsupportedStrategyError as exc: + raise HTTPException(status_code=501, detail=str(exc)) from exc - agent_factory = resolve_agent_factory(body.assistant_id) - graph_input = normalize_input(body.input) - config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) + 
# Upsert thread metadata so the thread appears in /threads/search, + # even for threads that were never explicitly created via POST /threads + # (e.g. stateless runs). + try: + existing = await run_ctx.thread_store.get(thread_id) + if existing is None and owner_user_id: + unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None) + if unscoped_existing is not None: + if unscoped_existing.get("user_id") != owner_user_id: + await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None) + existing = await run_ctx.thread_store.get(thread_id) + if existing is None: + await run_ctx.thread_store.create( + thread_id, + assistant_id=body.assistant_id, + metadata=body.metadata, + ) + else: + await run_ctx.thread_store.update_status(thread_id, "running") + except Exception: + logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) - # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``. - # The ``context`` field is a custom extension for the langgraph-compat layer - # that carries agent configuration (model_name, thinking_enabled, etc.). - # Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored. - merge_run_context_overrides(config, getattr(body, "context", None)) - inject_authenticated_user_context(config, request) + agent_factory = resolve_agent_factory(body.assistant_id) + graph_input = normalize_input(body.input) + config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) - stream_modes = normalize_stream_modes(body.stream_mode) + # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``. + # The ``context`` field is a custom extension for the langgraph-compat layer + # that carries agent configuration (model_name, thinking_enabled, etc.). + # Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored. 
+ merge_run_context_overrides(config, getattr(body, "context", None)) + inject_authenticated_user_context(config, request) - task = asyncio.create_task( - run_agent( - bridge, - run_mgr, - record, - ctx=run_ctx, - agent_factory=agent_factory, - graph_input=graph_input, - config=config, - stream_modes=stream_modes, - stream_subgraphs=body.stream_subgraphs, - interrupt_before=body.interrupt_before, - interrupt_after=body.interrupt_after, + stream_modes = normalize_stream_modes(body.stream_mode) + + task = asyncio.create_task( + run_agent( + bridge, + run_mgr, + record, + ctx=run_ctx, + agent_factory=agent_factory, + graph_input=graph_input, + config=config, + stream_modes=stream_modes, + stream_subgraphs=body.stream_subgraphs, + interrupt_before=body.interrupt_before, + interrupt_after=body.interrupt_after, + ) ) - ) - record.task = task + record.task = task - # Title sync is handled by worker.py's finally block which reads the - # title from the checkpoint and calls thread_store.update_display_name - # after the run completes. + # Title sync is handled by worker.py's finally block which reads the + # title from the checkpoint and calls thread_store.update_display_name + # after the run completes. - return record + return record + finally: + if owner_context_token is not None: + reset_current_user(owner_context_token) async def sse_consumer( diff --git a/backend/docs/IM_CHANNEL_CONNECTIONS.md b/backend/docs/IM_CHANNEL_CONNECTIONS.md new file mode 100644 index 000000000..996c83568 --- /dev/null +++ b/backend/docs/IM_CHANNEL_CONNECTIONS.md @@ -0,0 +1,122 @@ +# IM Channel Connections + +DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow. 
+ +No public IP, OAuth callback URL, or provider webhook is required in this implementation. + +## Configuration + +Configure the actual IM bots under the existing `channels` block: + +```yaml +channels: + telegram: + enabled: true + bot_token: $TELEGRAM_BOT_TOKEN + + slack: + enabled: true + bot_token: $SLACK_BOT_TOKEN + app_token: $SLACK_APP_TOKEN + + discord: + enabled: true + bot_token: $DISCORD_BOT_TOKEN + + feishu: + enabled: true + app_id: $FEISHU_APP_ID + app_secret: $FEISHU_APP_SECRET + + dingtalk: + enabled: true + client_id: $DINGTALK_CLIENT_ID + client_secret: $DINGTALK_CLIENT_SECRET + + wechat: + enabled: true + bot_token: $WECHAT_BOT_TOKEN + + wecom: + enabled: true + bot_id: $WECOM_BOT_ID + bot_secret: $WECOM_BOT_SECRET +``` + +Then enable user bindings in `channel_connections`: + +```yaml +channel_connections: + enabled: true + + telegram: + enabled: true + bot_username: $TELEGRAM_BOT_USERNAME + + slack: + enabled: true + + discord: + enabled: true + + feishu: + enabled: true + + dingtalk: + enabled: true + + wechat: + enabled: true + + wecom: + enabled: true +``` + +`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link. + +## Connect Flow + +Telegram: + +- The frontend creates a short one-time code. +- The Connect button opens `https://t.me/<bot_username>?start=<code>`. +- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user. + +Slack: + +- The frontend creates a short one-time code. +- The UI shows `Send /connect <code> to the DeerFlow Slack bot.` +- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user. + +Discord: + +- The frontend creates a short one-time code. 
+- The UI shows `Send /connect <code> to the DeerFlow Discord bot.` +- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user. + +Feishu/Lark, DingTalk, WeChat, and WeCom: + +- The frontend creates a short one-time code. +- The UI shows `Send /connect <code> to the DeerFlow bot.` +- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user. + +Codes use 128 bits of randomness, expire after 10 minutes, and are single-use. + +## Runtime Model + +Connection records live in SQL tables under `deerflow.persistence.channel_connections`: + +- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata. +- `channel_oauth_states`: one-time connect codes and Telegram deep-link state. +- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping. +- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow. + +Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`. + +## Security Notes + +- Browser APIs remain authenticated and CSRF-protected. +- Connect codes are 128-bit random, short-lived, and single-use. +- Provider bot tokens remain in `channels.*` and are never returned to the browser. +- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets. +- This implementation does not add public provider callback or webhook routes. 
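Reviewer note: the connect-code properties stated in the doc above (128 bits of randomness, 10-minute expiry, single use, stored only as a hash) can be illustrated with a small in-memory sketch. This is not the repository code from this patch: `InMemoryCodeStore`, `new_connect_code`, and `hash_code` are hypothetical names, the dictionary stands in for the `channel_oauth_states` table, and the use of `secrets.token_urlsafe` is an assumption, since the code generator itself is not part of this section of the diff.

```python
from __future__ import annotations

import hashlib
import secrets
from datetime import datetime, timedelta, timezone

CODE_TTL = timedelta(minutes=10)  # the documented 10-minute expiry


def new_connect_code() -> str:
    """Return a URL-safe code built from 16 random bytes (128 bits)."""
    return secrets.token_urlsafe(16)


def hash_code(code: str) -> str:
    """Only the SHA-256 hex digest is stored, never the raw code."""
    return hashlib.sha256(code.encode("utf-8")).hexdigest()


class InMemoryCodeStore:
    """Hypothetical stand-in for the channel_oauth_states table."""

    def __init__(self) -> None:
        self._rows: dict[str, dict] = {}

    def issue(self, owner_user_id: str, provider: str, now: datetime | None = None) -> str:
        now = now or datetime.now(timezone.utc)
        code = new_connect_code()
        self._rows[hash_code(code)] = {
            "owner_user_id": owner_user_id,
            "provider": provider,
            "expires_at": now + CODE_TTL,
            "consumed_at": None,
        }
        return code

    def consume(self, provider: str, code: str, now: datetime | None = None) -> str | None:
        """Return the owning user id once; None for unknown, expired, or reused codes."""
        now = now or datetime.now(timezone.utc)
        row = self._rows.get(hash_code(code))
        if row is None or row["provider"] != provider:
            return None
        if row["consumed_at"] is not None or row["expires_at"] < now:
            return None
        row["consumed_at"] = now  # flip from None exactly once
        return row["owner_user_id"]
```

A single-process dictionary needs no locking here. The SQL repository in this patch gets the same single-use guarantee across concurrent workers by using a conditional `UPDATE ... WHERE consumed_at IS NULL` and checking the affected row count.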
diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index 7352d0af7..5091b7d31 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict +from deerflow.config.channel_connections_config import ChannelConnectionsConfig from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict from deerflow.config.database_config import DatabaseConfig from deerflow.config.extensions_config import ExtensionsConfig @@ -116,6 +117,13 @@ class AppConfig(BaseModel): subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") + channel_connections: ChannelConnectionsConfig = Field( + default_factory=ChannelConnectionsConfig, + description=format_field_description( + "channel_connections", + field_doc="User-facing IM channel connection configuration.", + ), + ) loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") model_config = ConfigDict(extra="allow") diff --git a/backend/packages/harness/deerflow/config/channel_connections_config.py 
b/backend/packages/harness/deerflow/config/channel_connections_config.py new file mode 100644 index 000000000..4092d5863 --- /dev/null +++ b/backend/packages/harness/deerflow/config/channel_connections_config.py @@ -0,0 +1,61 @@ +"""Configuration for user-owned IM channel connections.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class SlackChannelConnectionConfig(BaseModel): + enabled: bool = False + + @property + def configured(self) -> bool: + return True + + +class TelegramChannelConnectionConfig(BaseModel): + enabled: bool = False + bot_username: str = "" + + @property + def configured(self) -> bool: + return bool(self.bot_username) + + +class DiscordChannelConnectionConfig(BaseModel): + enabled: bool = False + + @property + def configured(self) -> bool: + return True + + +class BindingCodeChannelConnectionConfig(BaseModel): + enabled: bool = False + + @property + def configured(self) -> bool: + return True + + +class ChannelConnectionsConfig(BaseModel): + """Top-level config for browser-connectable IM channels.""" + + enabled: bool = False + slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig) + telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig) + discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig) + feishu: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig) + dingtalk: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig) + wechat: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig) + wecom: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig) + + def provider_status(self, provider: str) -> dict[str, bool]: + config = getattr(self, provider, None) + if config is None: + return {"enabled": False, "configured": False} + 
enabled = bool(config.enabled) + return { + "enabled": enabled, + "configured": enabled and bool(config.configured), + } diff --git a/backend/packages/harness/deerflow/config/paths.py b/backend/packages/harness/deerflow/config/paths.py index f01959657..343ef70a1 100644 --- a/backend/packages/harness/deerflow/config/paths.py +++ b/backend/packages/harness/deerflow/config/paths.py @@ -1,4 +1,5 @@ import hashlib +import logging import os import re import shutil @@ -14,6 +15,8 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") _UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]") _SAFE_USER_ID_DIGEST_HEX_LEN = 16 +logger = logging.getLogger(__name__) + def _default_local_base_dir() -> Path: """Return the caller project's writable DeerFlow state directory.""" @@ -47,7 +50,13 @@ def make_safe_user_id(raw: str) -> str: sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw) if sanitized == raw: return raw - digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN] + digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN] + return f"{sanitized}-{digest}" + + +def _legacy_safe_user_id(raw: str, sanitized: str) -> str: + """Bucket name produced by the previous (SHA-1) digest revision for ``raw``.""" + digest = hashlib.sha1(raw.encode("utf-8"), usedforsecurity=False).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN] return f"{sanitized}-{digest}" @@ -172,6 +181,32 @@ class Paths: """Directory for a specific user: `{base_dir}/users/{user_id}/`.""" return self.base_dir / "users" / _validate_user_id(user_id) + def prepare_user_dir_for_raw_id(self, raw_user_id: str) -> str: + """Return the safe user ID and migrate this ID's legacy unsafe-id bucket. + + A previous branch revision used SHA-1 for unsafe external user IDs. 
+ New IDs use SHA-256; the legacy bucket name is recomputed from the same + raw ID, so only this user's own old bucket can ever be moved — a + different raw ID sharing the sanitized prefix produces a different + legacy digest and is never touched. + """ + safe_user_id = make_safe_user_id(raw_user_id) + sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw_user_id) + if safe_user_id == raw_user_id: + return safe_user_id + + users_dir = self.base_dir / "users" + target_dir = users_dir / safe_user_id + legacy_dir = users_dir / _legacy_safe_user_id(raw_user_id, sanitized) + try: + if target_dir.exists() or not legacy_dir.is_dir(): + return safe_user_id + legacy_dir.rename(target_dir) + logger.info("Migrated legacy unsafe-id user directory to the current digest format") + except OSError: + logger.exception("Failed to migrate legacy unsafe-id user directory") + return safe_user_id + def user_memory_file(self, user_id: str) -> Path: """Per-user memory file: `{base_dir}/users/{user_id}/memory.json`.""" return self.user_dir(user_id) / "memory.json" diff --git a/backend/packages/harness/deerflow/config/reload_boundary.py b/backend/packages/harness/deerflow/config/reload_boundary.py index d39502776..4c024fe52 100644 --- a/backend/packages/harness/deerflow/config/reload_boundary.py +++ b/backend/packages/harness/deerflow/config/reload_boundary.py @@ -56,6 +56,9 @@ STARTUP_ONLY_FIELDS: dict[str, str] = { # startup and the live channel clients are not rebuilt on # config.yaml edits. "channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."), + "channel_connections": ( + "start_channel_service() wires the connection repository and channel workers once at startup, and the channel-connections router caches the merged provider config on app.state; channel_connections.* edits need a restart." 
+ ), } diff --git a/backend/packages/harness/deerflow/persistence/channel_connections/__init__.py b/backend/packages/harness/deerflow/persistence/channel_connections/__init__.py new file mode 100644 index 000000000..f3829b004 --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/channel_connections/__init__.py @@ -0,0 +1,21 @@ +"""User-owned IM channel connection persistence.""" + +from deerflow.persistence.channel_connections.model import ( + ChannelConnectionRow, + ChannelConversationRow, + ChannelCredentialRow, + ChannelOAuthStateRow, +) +from deerflow.persistence.channel_connections.sql import ( + ChannelConnectionRepository, + ChannelCredentialCipher, +) + +__all__ = [ + "ChannelConnectionRepository", + "ChannelConnectionRow", + "ChannelConversationRow", + "ChannelCredentialCipher", + "ChannelCredentialRow", + "ChannelOAuthStateRow", +] diff --git a/backend/packages/harness/deerflow/persistence/channel_connections/model.py b/backend/packages/harness/deerflow/persistence/channel_connections/model.py new file mode 100644 index 000000000..94d1a5c4d --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/channel_connections/model.py @@ -0,0 +1,111 @@ +"""ORM models for user-owned IM channel connections.""" + +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from deerflow.persistence.base import Base + + +def _utc_now() -> datetime: + return datetime.now(UTC) + + +class ChannelConnectionRow(Base): + __tablename__ = "channel_connections" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True) + status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected") + + 
external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True) + workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True) + bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True) + + scopes_json: Mapped[list] = mapped_column(JSON, default=list) + capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict) + metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) + + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now) + last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + + __table_args__ = ( + UniqueConstraint( + "owner_user_id", + "provider", + "external_account_id", + "workspace_id", + name="uq_channel_connection_owner_provider_identity", + ), + Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"), + ) + + +class ChannelCredentialRow(Base): + __tablename__ = "channel_credentials" + + connection_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("channel_connections.id", ondelete="CASCADE"), + primary_key=True, + ) + encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True) + encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True) + token_type: Mapped[str | None] = mapped_column(String(32), nullable=True) + expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + 
encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True) + version: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now) + + +class ChannelOAuthStateRow(Base): + __tablename__ = "channel_oauth_states" + + state_hash: Mapped[str] = mapped_column(String(128), primary_key=True) + owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True) + code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True) + nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True) + redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True) + requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list) + metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) + expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now) + + +class ChannelConversationRow(Base): + __tablename__ = "channel_conversations" + + id: Mapped[str] = mapped_column(String(64), primary_key=True) + connection_id: Mapped[str] = mapped_column( + String(64), + ForeignKey("channel_connections.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True) + external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False) + external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + 
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now) + + __table_args__ = ( + UniqueConstraint( + "connection_id", + "external_conversation_id", + "external_topic_id", + name="uq_channel_conversation_connection_external", + ), + ) diff --git a/backend/packages/harness/deerflow/persistence/channel_connections/sql.py b/backend/packages/harness/deerflow/persistence/channel_connections/sql.py new file mode 100644 index 000000000..4739fd3c9 --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/channel_connections/sql.py @@ -0,0 +1,387 @@ +"""SQL repository for user-owned IM channel connections.""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import uuid +from datetime import UTC, datetime +from typing import Any + +from cryptography.fernet import Fernet, InvalidToken +from sqlalchemy import delete, func, select, update +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from deerflow.persistence.channel_connections.model import ( + ChannelConnectionRow, + ChannelConversationRow, + ChannelCredentialRow, + ChannelOAuthStateRow, +) +from deerflow.utils.time import coerce_iso + +logger = logging.getLogger(__name__) + + +class ChannelCredentialCipher: + """Encrypts provider credentials before they are persisted.""" + + def __init__(self, fernet: Fernet) -> None: + self._fernet = fernet + + @classmethod + def from_key(cls, key: str) -> ChannelCredentialCipher: + digest = hashlib.sha256(key.encode("utf-8")).digest() + return cls(Fernet(base64.urlsafe_b64encode(digest))) + + def encrypt_text(self, value: str | None) -> str | None: + if value is None: + return None + return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii") + + def decrypt_text(self, 
value: str | None) -> str | None: + if value is None: + return None + token = value.removeprefix("fernet:v1:") + return self._fernet.decrypt(token.encode("ascii")).decode("utf-8") + + +class ChannelConnectionRepository: + """Persistence facade for channel connections, credentials, and conversations.""" + + def __init__( + self, + session_factory: async_sessionmaker[AsyncSession], + *, + cipher: ChannelCredentialCipher | None = None, + ) -> None: + self.session_factory = session_factory + self._cipher = cipher + + async def close(self) -> None: + from deerflow.persistence.engine import close_engine + + await close_engine() + + @staticmethod + def _new_id() -> str: + return uuid.uuid4().hex + + @staticmethod + def _normalize_optional_identity(value: str | None) -> str: + return value or "" + + @staticmethod + def _coerce_datetime(value: datetime | None) -> datetime | None: + if value is None or value.tzinfo is not None: + return value + return value.replace(tzinfo=UTC) + + def _encrypt_optional_secret(self, value: str | None) -> str | None: + if value is None: + return None + if self._cipher is None: + raise RuntimeError("channel connection encryption key is required") + return self._cipher.encrypt_text(value) + + @staticmethod + def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]: + data = row.to_dict() + data["external_account_id"] = data["external_account_id"] or None + data["workspace_id"] = data["workspace_id"] or None + data["scopes"] = data.pop("scopes_json") or [] + data["capabilities"] = data.pop("capabilities_json") or {} + data["metadata"] = data.pop("metadata_json") or {} + for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"): + value = data.get(key) + if isinstance(value, datetime): + data[key] = coerce_iso(value) + return data + + async def upsert_connection( + self, + *, + owner_user_id: str, + provider: str, + external_account_id: str | None = None, + external_account_name: str | None = None, + workspace_id: str | 
None = None, + workspace_name: str | None = None, + bot_user_id: str | None = None, + scopes: list[str] | None = None, + capabilities: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + status: str = "connected", + ) -> dict[str, Any]: + external_account_id_value = self._normalize_optional_identity(external_account_id) + workspace_id_value = self._normalize_optional_identity(workspace_id) + + def _apply(row: ChannelConnectionRow) -> None: + row.status = status + row.external_account_name = external_account_name + row.workspace_name = workspace_name + row.bot_user_id = bot_user_id + row.scopes_json = list(scopes or []) + row.capabilities_json = dict(capabilities or {}) + row.metadata_json = dict(metadata or {}) + + stmt = select(ChannelConnectionRow).where( + ChannelConnectionRow.owner_user_id == owner_user_id, + ChannelConnectionRow.provider == provider, + ChannelConnectionRow.external_account_id == external_account_id_value, + ChannelConnectionRow.workspace_id == workspace_id_value, + ) + async with self.session_factory() as session: + row = (await session.execute(stmt)).scalar_one_or_none() + if row is None: + row = ChannelConnectionRow( + id=self._new_id(), + owner_user_id=owner_user_id, + provider=provider, + external_account_id=external_account_id_value, + workspace_id=workspace_id_value, + ) + session.add(row) + + _apply(row) + try: + await session.commit() + except IntegrityError: + # A concurrent writer inserted the same identity first; retry as + # an update of that row. 
+ await session.rollback() + row = (await session.execute(stmt)).scalar_one() + _apply(row) + await session.commit() + await session.refresh(row) + return self._connection_to_dict(row) + + async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]: + async with self.session_factory() as session: + result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())) + return [self._connection_to_dict(row) for row in result.scalars()] + + async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool: + async with self.session_factory() as session: + row = await session.get(ChannelConnectionRow, connection_id) + if row is None or row.owner_user_id != owner_user_id: + return False + + row.status = "revoked" + credential = await session.get(ChannelCredentialRow, connection_id) + if credential is not None: + await session.delete(credential) + await session.commit() + return True + + async def store_credentials( + self, + connection_id: str, + *, + access_token: str | None, + refresh_token: str | None = None, + token_type: str | None = None, + expires_at: datetime | None = None, + refresh_expires_at: datetime | None = None, + extra: dict[str, Any] | None = None, + ) -> None: + if self._cipher is None: + raise RuntimeError("channel connection encryption key is required") + async with self.session_factory() as session: + row = await session.get(ChannelCredentialRow, connection_id) + if row is None: + row = ChannelCredentialRow(connection_id=connection_id) + session.add(row) + row.encrypted_access_token = self._cipher.encrypt_text(access_token) + row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token) + row.token_type = token_type + row.expires_at = expires_at + row.refresh_expires_at = refresh_expires_at + row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, 
ensure_ascii=False)) + row.version = (row.version or 0) + 1 + await session.commit() + + async def get_credentials(self, connection_id: str) -> dict[str, Any] | None: + if self._cipher is None: + return None + async with self.session_factory() as session: + row = await session.get(ChannelCredentialRow, connection_id) + if row is None: + return None + try: + extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json) + return { + "connection_id": row.connection_id, + "access_token": self._cipher.decrypt_text(row.encrypted_access_token), + "refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token), + "token_type": row.token_type, + "expires_at": self._coerce_datetime(row.expires_at), + "refresh_expires_at": self._coerce_datetime(row.refresh_expires_at), + "extra": json.loads(extra_raw) if extra_raw else {}, + } + except (InvalidToken, UnicodeError, json.JSONDecodeError): + logger.warning( + "Unable to decrypt channel connection credentials; treating credentials as unavailable", + exc_info=True, + ) + return None + + @staticmethod + def hash_state(state: str) -> str: + return hashlib.sha256(state.encode("utf-8")).hexdigest() + + async def create_oauth_state( + self, + *, + owner_user_id: str, + provider: str, + state: str, + expires_at: datetime, + code_verifier: str | None = None, + nonce_hash: str | None = None, + redirect_after: str | None = None, + requested_scopes: list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + row = ChannelOAuthStateRow( + state_hash=self.hash_state(state), + owner_user_id=owner_user_id, + provider=provider, + code_verifier_encrypted=self._encrypt_optional_secret(code_verifier), + nonce_hash=nonce_hash, + redirect_after=redirect_after, + requested_scopes_json=list(requested_scopes or []), + metadata_json=dict(metadata or {}), + expires_at=expires_at, + ) + async with self.session_factory() as session: + session.add(row) + await session.commit() + + async def count_oauth_states(self, *, 
owner_user_id: str, provider: str) -> int: + async with self.session_factory() as session: + result = await session.execute( + select(func.count()) + .select_from(ChannelOAuthStateRow) + .where( + ChannelOAuthStateRow.owner_user_id == owner_user_id, + ChannelOAuthStateRow.provider == provider, + ) + ) + return int(result.scalar_one()) + + async def consume_oauth_state( + self, + *, + provider: str, + state: str, + now: datetime | None = None, + ) -> dict[str, Any] | None: + current_time = now or datetime.now(UTC) + state_hash = self.hash_state(state) + async with self.session_factory() as session: + await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time)) + row = await session.get(ChannelOAuthStateRow, state_hash) + if row is None or row.provider != provider or row.consumed_at is not None: + await session.commit() + return None + expires_at = self._coerce_datetime(row.expires_at) + if expires_at is not None and expires_at < current_time: + await session.commit() + return None + + # Conditional UPDATE so two concurrent workers cannot both consume + # the same binding code: only the writer that flips consumed_at + # from NULL wins. 
+ result = await session.execute( + update(ChannelOAuthStateRow) + .where( + ChannelOAuthStateRow.state_hash == state_hash, + ChannelOAuthStateRow.consumed_at.is_(None), + ) + .values(consumed_at=current_time) + ) + await session.commit() + if result.rowcount != 1: + return None + return { + "owner_user_id": row.owner_user_id, + "provider": row.provider, + "requested_scopes": row.requested_scopes_json or [], + "metadata": row.metadata_json or {}, + "redirect_after": row.redirect_after, + } + + async def find_connection_by_external_identity( + self, + *, + provider: str, + external_account_id: str, + workspace_id: str | None = None, + ) -> dict[str, Any] | None: + async with self.session_factory() as session: + result = await session.execute( + select(ChannelConnectionRow) + .where( + ChannelConnectionRow.provider == provider, + ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id), + ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id), + ChannelConnectionRow.status == "connected", + ) + .order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()) + .limit(1) + ) + row = result.scalar_one_or_none() + return self._connection_to_dict(row) if row is not None else None + + async def set_thread_id( + self, + *, + connection_id: str, + owner_user_id: str, + provider: str, + external_conversation_id: str, + thread_id: str, + external_topic_id: str | None = None, + ) -> None: + topic_id = external_topic_id or "" + async with self.session_factory() as session: + stmt = select(ChannelConversationRow).where( + ChannelConversationRow.connection_id == connection_id, + ChannelConversationRow.external_conversation_id == external_conversation_id, + ChannelConversationRow.external_topic_id == topic_id, + ) + row = (await session.execute(stmt)).scalar_one_or_none() + if row is None: + row = ChannelConversationRow( + id=self._new_id(), + connection_id=connection_id, + 
owner_user_id=owner_user_id, + provider=provider, + external_conversation_id=external_conversation_id, + external_topic_id=topic_id, + thread_id=thread_id, + ) + session.add(row) + else: + row.thread_id = thread_id + row.owner_user_id = owner_user_id + row.provider = provider + await session.commit() + + async def get_thread_id( + self, + connection_id: str, + external_conversation_id: str, + external_topic_id: str | None = None, + ) -> str | None: + async with self.session_factory() as session: + stmt = select(ChannelConversationRow.thread_id).where( + ChannelConversationRow.connection_id == connection_id, + ChannelConversationRow.external_conversation_id == external_conversation_id, + ChannelConversationRow.external_topic_id == (external_topic_id or ""), + ) + return (await session.execute(stmt)).scalar_one_or_none() diff --git a/backend/packages/harness/deerflow/persistence/models/__init__.py b/backend/packages/harness/deerflow/persistence/models/__init__.py index ab29a3536..40445f373 100644 --- a/backend/packages/harness/deerflow/persistence/models/__init__.py +++ b/backend/packages/harness/deerflow/persistence/models/__init__.py @@ -14,10 +14,26 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and there is no matching entity directory. 
""" +from deerflow.persistence.channel_connections.model import ( + ChannelConnectionRow, + ChannelConversationRow, + ChannelCredentialRow, + ChannelOAuthStateRow, +) from deerflow.persistence.feedback.model import FeedbackRow from deerflow.persistence.models.run_event import RunEventRow from deerflow.persistence.run.model import RunRow from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.user.model import UserRow -__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"] +__all__ = [ + "ChannelConnectionRow", + "ChannelConversationRow", + "ChannelCredentialRow", + "ChannelOAuthStateRow", + "FeedbackRow", + "RunEventRow", + "RunRow", + "ThreadMetaRow", + "UserRow", +] diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index ed55ade8e..4207b4daa 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC): """ pass + @abc.abstractmethod + async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + """Move a thread metadata row to a new owner. + + Intended for trusted internal repair/migration paths. No-op if the + row does not exist or the caller fails the owner check. 
+ """ + pass + @abc.abstractmethod async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: """Check if ``user_id`` has access to ``thread_id``.""" diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index 4f642a938..b17d994f8 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore): record["updated_at"] = now_iso() await self._store.aput(THREADS_NS, thread_id, record) + async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner") + if record is None: + return + record["user_id"] = owner_user_id + record["updated_at"] = now_iso() + await self._store.aput(THREADS_NS, thread_id, record) + async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete") if record is None: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 930128087..a5e7f51c5 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore): row.updated_at = datetime.now(UTC) await session.commit() + async def update_owner( + self, + thread_id: str, + owner_user_id: str, + *, + user_id: str | None | _AutoSentinel = AUTO, + ) -> None: + """Move a thread metadata row to ``owner_user_id``.""" + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner") 
+ async with self._sf() as session: + if not await self._check_ownership(session, thread_id, resolved_user_id): + return + await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC))) + await session.commit() + async def delete( self, thread_id: str, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index ef45852fb..9a9082fb7 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -83,6 +83,7 @@ class RunRecord: multitask_strategy: str = "reject" metadata: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict) + user_id: str | None = None created_at: str = "" updated_at: str = "" task: asyncio.Task | None = field(default=None, repr=False) @@ -124,7 +125,7 @@ class RunManager: @staticmethod def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: - return { + payload = { "thread_id": record.thread_id, "assistant_id": record.assistant_id, "status": record.status.value, @@ -135,6 +136,9 @@ class RunManager: "created_at": record.created_at, "model_name": record.model_name, } + if record.user_id is not None: + payload["user_id"] = record.user_id + return payload async def _call_store_with_retry( self, @@ -241,6 +245,7 @@ class RunManager: kwargs=row.get("kwargs") or {}, created_at=row.get("created_at") or "", updated_at=row.get("updated_at") or "", + user_id=row.get("user_id"), error=row.get("error"), model_name=row.get("model_name"), store_only=True, @@ -320,6 +325,7 @@ class RunManager: metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", + user_id: str | None = None, ) -> RunRecord: """Create a new pending run and register it.""" run_id = str(uuid.uuid4()) @@ -333,6 +339,7 @@ class RunManager: 
multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, + user_id=user_id, created_at=now, updated_at=now, ) @@ -504,6 +511,7 @@ class RunManager: kwargs: dict | None = None, multitask_strategy: str = "reject", model_name: str | None = None, + user_id: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. @@ -546,6 +554,7 @@ class RunManager: multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, + user_id=user_id, created_at=now, updated_at=now, model_name=model_name, diff --git a/backend/packages/harness/pyproject.toml b/backend/packages/harness/pyproject.toml index 47cd1afad..1d6d002d8 100644 --- a/backend/packages/harness/pyproject.toml +++ b/backend/packages/harness/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "sqlalchemy[asyncio]>=2.0,<3.0", "aiosqlite>=0.19", "alembic>=1.13", + "cryptography>=43.0.0", ] [project.optional-dependencies] diff --git a/backend/tests/blocking_io/test_channel_runtime_config_store.py b/backend/tests/blocking_io/test_channel_runtime_config_store.py new file mode 100644 index 000000000..eaafad49b --- /dev/null +++ b/backend/tests/blocking_io/test_channel_runtime_config_store.py @@ -0,0 +1,106 @@ +"""Regression anchors: channel runtime-config handlers must not block the event loop. + +``configure_channel_provider_runtime`` and ``disconnect_channel_provider_runtime`` +persist UI-entered channel credentials through ``ChannelRuntimeConfigStore``, +whose construction reads its JSON file and whose setters rewrite it +(``json.dump`` + ``Path.replace`` + ``chmod``). The handlers offload both via +``asyncio.to_thread``; if that regresses back onto the event loop, the strict +Blockbuster gate raises ``BlockingError`` and these tests fail. + +The handlers are invoked directly with a minimal Starlette ``Request`` so the +surface under test is exactly the router's own IO, mirroring +``test_agents_router``. 
Test-side seeding/inspection is offloaded with +``asyncio.to_thread``. +""" + +from __future__ import annotations + +import asyncio +import importlib +from types import SimpleNamespace +from uuid import UUID + +import pytest +from fastapi import FastAPI, Request + +from app.channels.runtime_config_store import ChannelRuntimeConfigStore +from app.gateway.routers.channel_connections import ( + ChannelRuntimeConfigRequest, + configure_channel_provider_runtime, + disconnect_channel_provider_runtime, +) +from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config +from deerflow.config.channel_connections_config import ChannelConnectionsConfig + +# Pre-import: the handlers import this module lazily; the import's file IO +# must happen at collection time, not on the event loop under the gate. +importlib.import_module("app.channels.service") + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture(autouse=True) +def _stub_app_config(): + set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}})) + yield + reset_app_config() + + +def _make_request(tmp_path) -> Request: + app = FastAPI() + app.state.channel_connections_config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + app.state.channels_config = {} + app.state.channel_connection_repo = _FakeRepo() + store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + app.state.channel_runtime_config_store = store + user = SimpleNamespace(id=UUID("11111111-2222-3333-4444-555555555555"), system_role="admin") + return Request({"type": "http", "app": app, "headers": [], "state": {"user": user}}) + + +class _FakeRepo: + async def list_connections(self, owner_user_id): + return [] + + +async def test_configure_runtime_channel_does_not_block_event_loop(tmp_path) -> None: + request = await asyncio.to_thread(_make_request, tmp_path) + + response = await 
configure_channel_provider_runtime( + "slack", + ChannelRuntimeConfigRequest(values={"bot_token": "xoxb-ui", "app_token": "xapp-ui"}), + request, + ) + + assert response.provider == "slack" + store = request.app.state.channel_runtime_config_store + assert await asyncio.to_thread(store.get_provider_config, "slack") == { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + } + + +async def test_disconnect_runtime_channel_does_not_block_event_loop(tmp_path) -> None: + request = await asyncio.to_thread(_make_request, tmp_path) + store = request.app.state.channel_runtime_config_store + await asyncio.to_thread( + store.set_provider_config, + "slack", + {"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"}, + ) + request.app.state.channels_config = { + "slack": {"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"}, + } + + response = await disconnect_channel_provider_runtime("slack", request) + + assert response.provider == "slack" + assert await asyncio.to_thread(store.get_provider_config, "slack") == { + "enabled": False, + "_runtime_disabled": True, + } diff --git a/backend/tests/test_additional_channel_connections.py b/backend/tests/test_additional_channel_connections.py new file mode 100644 index 000000000..d6134a355 --- /dev/null +++ b/backend/tests/test_additional_channel_connections.py @@ -0,0 +1,251 @@ +"""Connection binding tests for browser-connectable IM channels beyond Telegram/Slack/Discord.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +from app.channels.message_bus import InboundMessage, MessageBus + + +async def _make_repo(tmp_path, name: str): + from deerflow.persistence.channel_connections import ChannelConnectionRepository + from deerflow.persistence.engine import get_session_factory, init_engine + + await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / f'{name}.db'}", sqlite_dir=str(tmp_path)) + return 
ChannelConnectionRepository(get_session_factory()) + + +async def _seed_state(repo, provider: str, state: str, owner_user_id: str = "deerflow-user-1") -> None: + await repo.create_oauth_state( + owner_user_id=owner_user_id, + provider=provider, + state=state, + expires_at=datetime.now(UTC) + timedelta(minutes=5), + ) + + +def test_feishu_connect_command_binds_identity(tmp_path): + import anyio + + from app.channels.feishu import FeishuChannel + + async def go(): + repo = await _make_repo(tmp_path, "feishu") + state = "feishu-bind-code" + await _seed_state(repo, "feishu", state) + channel = FeishuChannel( + bus=MessageBus(), + config={"app_id": "app", "app_secret": "secret", "connection_repo": repo}, + ) + channel._reply_card = AsyncMock() + + handled = await channel._bind_connection_from_connect_code( + message_id="om-message-1", + chat_id="oc-chat-1", + user_id="ou-user-1", + code=state, + ) + + connections = await repo.list_connections("deerflow-user-1") + assert handled is True + assert len(connections) == 1 + assert connections[0]["provider"] == "feishu" + assert connections[0]["external_account_id"] == "ou-user-1" + assert connections[0]["workspace_id"] == "oc-chat-1" + channel._reply_card.assert_awaited_once_with("om-message-1", "Feishu connected to DeerFlow.") + await repo.close() + + anyio.run(go) + + +def test_dingtalk_connect_command_binds_identity(tmp_path): + import anyio + + from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel + + async def go(): + repo = await _make_repo(tmp_path, "dingtalk") + state = "dingtalk-bind-code" + await _seed_state(repo, "dingtalk", state) + channel = DingTalkChannel( + bus=MessageBus(), + config={"client_id": "client", "client_secret": "secret", "connection_repo": repo}, + ) + channel._send_connection_reply = AsyncMock() + + handled = await channel._bind_connection_from_connect_code( + conversation_type=_CONVERSATION_TYPE_GROUP, + sender_staff_id="staff-user-1", + sender_nick="Alice", + 
conversation_id="cid-group-1", + code=state, + ) + + connections = await repo.list_connections("deerflow-user-1") + assert handled is True + assert len(connections) == 1 + assert connections[0]["provider"] == "dingtalk" + assert connections[0]["external_account_id"] == "staff-user-1" + assert connections[0]["external_account_name"] == "Alice" + assert connections[0]["workspace_id"] == "cid-group-1" + channel._send_connection_reply.assert_awaited_once() + await repo.close() + + anyio.run(go) + + +def test_wechat_connect_command_binds_identity(tmp_path): + import anyio + + from app.channels.wechat import WechatChannel + + async def go(): + repo = await _make_repo(tmp_path, "wechat") + state = "wechat-bind-code" + await _seed_state(repo, "wechat", state) + channel = WechatChannel( + bus=MessageBus(), + config={"bot_token": "token", "connection_repo": repo}, + ) + channel._send_connection_reply = AsyncMock() + + handled = await channel._bind_connection_from_connect_code( + chat_id="wx-user-1", + context_token="ctx-1", + code=state, + ) + + connections = await repo.list_connections("deerflow-user-1") + assert handled is True + assert len(connections) == 1 + assert connections[0]["provider"] == "wechat" + assert connections[0]["external_account_id"] == "wx-user-1" + assert connections[0]["workspace_id"] == "wx-user-1" + channel._send_connection_reply.assert_awaited_once_with("wx-user-1", "ctx-1", "WeChat connected to DeerFlow.") + await repo.close() + + anyio.run(go) + + +def test_wecom_connect_command_binds_identity(tmp_path): + import anyio + + from app.channels.wecom import WeComChannel + + async def go(): + repo = await _make_repo(tmp_path, "wecom") + state = "wecom-bind-code" + await _seed_state(repo, "wecom", state) + channel = WeComChannel( + bus=MessageBus(), + config={"bot_id": "bot", "bot_secret": "secret", "connection_repo": repo}, + ) + channel._ws_client = MagicMock() + channel._ws_client.reply = AsyncMock() + frame = {"body": {"aibotid": "bot-1", 
"chattype": "single"}} + + handled = await channel._bind_connection_from_connect_code( + frame=frame, + user_id="wecom-user-1", + code=state, + ) + + connections = await repo.list_connections("deerflow-user-1") + assert handled is True + assert len(connections) == 1 + assert connections[0]["provider"] == "wecom" + assert connections[0]["external_account_id"] == "wecom-user-1" + assert connections[0]["workspace_id"] == "bot-1" + channel._ws_client.reply.assert_awaited_once_with(frame, {"msgtype": "text", "text": {"content": "WeCom connected to DeerFlow."}}) + await repo.close() + + anyio.run(go) + + +def test_additional_channels_attach_owner_identity(tmp_path): + import anyio + + from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel + from app.channels.feishu import FeishuChannel + from app.channels.wechat import WechatChannel + from app.channels.wecom import WeComChannel + + async def go(): + repo = await _make_repo(tmp_path, "additional-identity") + await repo.upsert_connection( + owner_user_id="deerflow-user-1", + provider="feishu", + external_account_id="ou-user-1", + workspace_id="oc-chat-1", + ) + await repo.upsert_connection( + owner_user_id="deerflow-user-1", + provider="dingtalk", + external_account_id="staff-user-1", + workspace_id="cid-group-1", + ) + await repo.upsert_connection( + owner_user_id="deerflow-user-1", + provider="wechat", + external_account_id="wx-user-1", + workspace_id="wx-user-1", + ) + await repo.upsert_connection( + owner_user_id="deerflow-user-1", + provider="wecom", + external_account_id="wecom-user-1", + workspace_id="bot-1", + ) + + cases = [ + ( + FeishuChannel(bus=MessageBus(), config={"connection_repo": repo}), + InboundMessage(channel_name="feishu", chat_id="oc-chat-1", user_id="ou-user-1", text="hello"), + ), + ( + DingTalkChannel(bus=MessageBus(), config={"connection_repo": repo}), + InboundMessage( + channel_name="dingtalk", + chat_id="cid-group-1", + user_id="staff-user-1", + text="hello", + metadata={ 
+ "conversation_type": _CONVERSATION_TYPE_GROUP, + "conversation_id": "cid-group-1", + }, + ), + ), + ( + WechatChannel(bus=MessageBus(), config={"connection_repo": repo}), + InboundMessage(channel_name="wechat", chat_id="wx-user-1", user_id="wx-user-1", text="hello"), + ), + ( + WeComChannel(bus=MessageBus(), config={"connection_repo": repo}), + InboundMessage( + channel_name="wecom", + chat_id="wecom-user-1", + user_id="wecom-user-1", + text="hello", + metadata={"aibotid": "bot-1"}, + ), + ), + ] + + for channel, inbound in cases: + attached = await channel._attach_connection_identity(inbound) + assert attached.owner_user_id == "deerflow-user-1" + assert attached.connection_id + assert ( + attached.workspace_id + == { + "feishu": "oc-chat-1", + "dingtalk": "cid-group-1", + "wechat": "wx-user-1", + "wecom": "bot-1", + }[channel.name] + ) + + await repo.close() + + anyio.run(go) diff --git a/backend/tests/test_auth.py b/backend/tests/test_auth.py index f19c83c7d..b16546fdf 100644 --- a/backend/tests/test_auth.py +++ b/backend/tests/test_auth.py @@ -280,6 +280,74 @@ def test_require_permission_denies_wrong_permission(): assert "Permission denied" in response.json()["detail"] +def _make_internal_owner_check_app(): + """App with an owner_check route and a thread owned by ``alice``.""" + import asyncio + + from fastapi import Request + from langgraph.store.memory import InMemoryStore + + from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore + + app = FastAPI() + thread_store = MemoryThreadMetaStore(InMemoryStore()) + asyncio.run(thread_store.create("alice-thread", user_id="alice")) + app.state.thread_store = thread_store + + @app.get("/threads/{thread_id}") + @require_permission("threads", "read", owner_check=True) + async def endpoint(thread_id: str, request: Request): + return {"ok": True} + + return app + + +def _internal_auth_context() -> AuthContext: + from types import SimpleNamespace + + from app.gateway.internal_auth import 
INTERNAL_SYSTEM_ROLE + + user = SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE) + return AuthContext(user=user, permissions=[Permissions.THREADS_READ]) + + +def test_require_permission_internal_role_scoped_by_owner_header(): + """An internal caller acting for the thread owner passes the owner check.""" + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME + + app = _make_internal_owner_check_app() + with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()): + with TestClient(app) as client: + response = client.get( + "/threads/alice-thread", + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "alice"}, + ) + assert response.status_code == 200 + + +def test_require_permission_internal_role_denied_for_other_owner(): + """The internal token must not grant access to another user's thread.""" + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME + + app = _make_internal_owner_check_app() + with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()): + with TestClient(app) as client: + response = client.get( + "/threads/alice-thread", + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "mallory"}, + ) + assert response.status_code == 404 + + +def test_require_permission_internal_role_without_header_is_scoped_to_internal_user(): + """With no owner header, internal callers are scoped like before the bypass.""" + app = _make_internal_owner_check_app() + with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()): + with TestClient(app) as client: + response = client.get("/threads/alice-thread") + assert response.status_code == 404 + + # ── Weak JWT secret warning ────────────────────────────────────────────────── diff --git a/backend/tests/test_auth_middleware.py b/backend/tests/test_auth_middleware.py index 838bf57af..ab2e817eb 100644 --- a/backend/tests/test_auth_middleware.py +++ b/backend/tests/test_auth_middleware.py @@ -39,6 +39,8 @@ def 
test_public_paths(path: str): "/api/threads/123/uploads", "/api/agents", "/api/channels", + "/api/channels/providers", + "/api/channels/slack/connect", "/api/runs/stream", "/api/threads/123/runs", "/api/v1/auth/me", @@ -183,7 +185,7 @@ def _make_auth_csrf_app(): @pytest.fixture def client(monkeypatch): - monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "") return TestClient(_make_app()) @@ -221,7 +223,7 @@ def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch): assert res.json() == {"models": []} -def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch): +def test_auth_disabled_stamps_default_admin_user_without_cookie(monkeypatch): monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") client = TestClient(_make_app()) @@ -229,10 +231,10 @@ def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch): assert res.status_code == 200 assert res.json() == { - "id": "e2e-user", - "email": "e2e@test.local", + "id": "default", + "email": "default@test.local", "system_role": "admin", - "context_user_id": "e2e-user", + "context_user_id": "default", } @@ -244,8 +246,8 @@ def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch assert res.status_code == 200 assert res.json() == { - "id": "e2e-user", - "email": "e2e@test.local", + "id": "default", + "email": "default@test.local", "system_role": "admin", "needs_setup": False, } @@ -329,7 +331,7 @@ def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog): warn_if_auth_disabled_enabled() assert "authentication is bypassed" in caplog.text - assert "e2e-user" in caplog.text + assert "default" in caplog.text def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog): @@ -348,7 +350,8 @@ def test_protected_path_with_junk_cookie_rejected(client): """Junk cookie → 401. 
Middleware strictly validates the JWT now (AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad tokens through to the route handler.""" - res = client.get("/api/models", cookies={"access_token": "some-token"}) + client.cookies.set("access_token", "some-token") + res = client.get("/api/models") assert res.status_code == 401 diff --git a/backend/tests/test_channel_connections_config.py b/backend/tests/test_channel_connections_config.py new file mode 100644 index 000000000..8a14878c0 --- /dev/null +++ b/backend/tests/test_channel_connections_config.py @@ -0,0 +1,56 @@ +"""Tests for user-facing IM channel connection configuration.""" + +from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + +def test_channel_connections_disabled_by_default(): + config = ChannelConnectionsConfig() + + assert config.enabled is False + assert config.slack.enabled is False + assert config.telegram.enabled is False + assert config.discord.enabled is False + assert config.feishu.enabled is False + assert config.dingtalk.enabled is False + assert config.wechat.enabled is False + assert config.wecom.enabled is False + + +def test_enabled_channel_connections_do_not_require_public_url_or_encryption_key(): + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "telegram": { + "enabled": True, + "bot_username": "deerflow_bot", + }, + "slack": {"enabled": True}, + "discord": {"enabled": True}, + "feishu": {"enabled": True}, + "dingtalk": {"enabled": True}, + "wechat": {"enabled": True}, + "wecom": {"enabled": True}, + } + ) + + assert config.enabled is True + assert config.provider_status("telegram") == {"enabled": True, "configured": True} + assert config.provider_status("slack") == {"enabled": True, "configured": True} + assert config.provider_status("discord") == {"enabled": True, "configured": True} + assert config.provider_status("feishu") == {"enabled": True, "configured": True} + assert config.provider_status("dingtalk") == 
{"enabled": True, "configured": True} + assert config.provider_status("wechat") == {"enabled": True, "configured": True} + assert config.provider_status("wecom") == {"enabled": True, "configured": True} + + +def test_provider_status_reports_disabled_and_unknown_providers(): + config = ChannelConnectionsConfig.model_validate({"enabled": True}) + + assert config.provider_status("slack") == {"enabled": False, "configured": False} + assert config.provider_status("telegram") == {"enabled": False, "configured": False} + assert config.provider_status("discord") == {"enabled": False, "configured": False} + assert config.provider_status("feishu") == {"enabled": False, "configured": False} + assert config.provider_status("dingtalk") == {"enabled": False, "configured": False} + assert config.provider_status("wechat") == {"enabled": False, "configured": False} + assert config.provider_status("wecom") == {"enabled": False, "configured": False} + assert config.provider_status("unknown") == {"enabled": False, "configured": False} diff --git a/backend/tests/test_channel_connections_repository.py b/backend/tests/test_channel_connections_repository.py new file mode 100644 index 000000000..ae5610f89 --- /dev/null +++ b/backend/tests/test_channel_connections_repository.py @@ -0,0 +1,331 @@ +"""Tests for per-user IM channel connection persistence.""" + +from __future__ import annotations + +import logging +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy import select + +from deerflow.persistence.channel_connections import ( + ChannelConnectionRepository, + ChannelConnectionRow, + ChannelCredentialCipher, + ChannelCredentialRow, + ChannelOAuthStateRow, +) + + +@pytest.fixture +async def repo(tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + try: + yield ChannelConnectionRepository( 
+ get_session_factory(), + cipher=ChannelCredentialCipher.from_key("test-encryption-key"), + ) + finally: + await close_engine() + + +class TestChannelConnectionRepository: + @pytest.mark.anyio + async def test_connections_are_listed_per_owner(self, repo): + alice = await repo.upsert_connection( + owner_user_id="alice", + provider="slack", + external_account_id="U-alice", + external_account_name="Alice", + workspace_id="T1", + workspace_name="Team One", + scopes=["chat:write"], + ) + await repo.upsert_connection( + owner_user_id="bob", + provider="slack", + external_account_id="U-bob", + external_account_name="Bob", + workspace_id="T1", + workspace_name="Team One", + scopes=["chat:write"], + ) + + results = await repo.list_connections("alice") + + assert [item["id"] for item in results] == [alice["id"]] + assert results[0]["owner_user_id"] == "alice" + assert results[0]["provider"] == "slack" + assert results[0]["scopes"] == ["chat:write"] + assert "encrypted_access_token" not in results[0] + + @pytest.mark.anyio + async def test_upsert_connection_updates_existing_provider_identity(self, repo): + first = await repo.upsert_connection( + owner_user_id="alice", + provider="telegram", + external_account_id="42", + external_account_name="Alice", + workspace_id=None, + workspace_name=None, + status="pending", + ) + second = await repo.upsert_connection( + owner_user_id="alice", + provider="telegram", + external_account_id="42", + external_account_name="Alice Telegram", + workspace_id=None, + workspace_name=None, + status="connected", + ) + + assert second["id"] == first["id"] + assert second["status"] == "connected" + assert second["external_account_name"] == "Alice Telegram" + assert len(await repo.list_connections("alice")) == 1 + + @pytest.mark.anyio + async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo): + connection = await repo.upsert_connection( + owner_user_id="alice", + provider="slack", + external_account_id="U-alice", + 
workspace_id="T1", + ) + expires_at = datetime.now(UTC) + timedelta(hours=1) + + await repo.store_credentials( + connection["id"], + access_token="xoxb-secret-access-token", + refresh_token="secret-refresh-token", + token_type="Bearer", + expires_at=expires_at, + extra={"bot_user_id": "B123"}, + ) + + async with repo.session_factory() as session: + row = (await session.execute(select(ChannelCredentialRow))).scalar_one() + assert row.encrypted_access_token is not None + assert "xoxb-secret-access-token" not in row.encrypted_access_token + assert "secret-refresh-token" not in (row.encrypted_refresh_token or "") + assert "B123" not in (row.encrypted_extra_json or "") + + credentials = await repo.get_credentials(connection["id"]) + + assert credentials is not None + assert credentials["access_token"] == "xoxb-secret-access-token" + assert credentials["refresh_token"] == "secret-refresh-token" + assert credentials["token_type"] == "Bearer" + assert credentials["expires_at"] == expires_at + assert credentials["extra"] == {"bot_user_id": "B123"} + + @pytest.mark.anyio + async def test_get_credentials_returns_none_when_decryption_fails(self, repo, caplog): + connection = await repo.upsert_connection( + owner_user_id="alice", + provider="slack", + external_account_id="U-alice", + workspace_id="T1", + ) + await repo.store_credentials(connection["id"], access_token="xoxb-secret-access-token") + wrong_key_repo = ChannelConnectionRepository( + repo.session_factory, + cipher=ChannelCredentialCipher.from_key("wrong-encryption-key"), + ) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.channel_connections.sql"): + credentials = await wrong_key_repo.get_credentials(connection["id"]) + + assert credentials is None + assert any("Unable to decrypt channel connection credentials" in record.message for record in caplog.records) + + @pytest.mark.anyio + async def test_conversations_are_scoped_by_connection(self, repo): + alice = await repo.upsert_connection( + 
owner_user_id="alice", + provider="slack", + external_account_id="U-alice", + workspace_id="T1", + ) + bob = await repo.upsert_connection( + owner_user_id="bob", + provider="slack", + external_account_id="U-bob", + workspace_id="T1", + ) + + await repo.set_thread_id( + connection_id=alice["id"], + owner_user_id="alice", + provider="slack", + external_conversation_id="C-shared", + external_topic_id="1710000000.000100", + thread_id="thread-alice", + ) + await repo.set_thread_id( + connection_id=bob["id"], + owner_user_id="bob", + provider="slack", + external_conversation_id="C-shared", + external_topic_id="1710000000.000100", + thread_id="thread-bob", + ) + + assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice" + assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob" + + @pytest.mark.anyio + async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo): + connection = await repo.upsert_connection( + owner_user_id="alice", + provider="telegram", + external_account_id="42", + ) + await repo.store_credentials(connection["id"], access_token="secret-token") + + disconnected = await repo.disconnect_connection( + connection_id=connection["id"], + owner_user_id="alice", + ) + + assert disconnected is True + async with repo.session_factory() as session: + connection_row = await session.get(ChannelConnectionRow, connection["id"]) + credential_row = await session.get(ChannelCredentialRow, connection["id"]) + assert connection_row is not None + assert connection_row.status == "revoked" + assert credential_row is None + assert ( + await repo.find_connection_by_external_identity( + provider="telegram", + external_account_id="42", + ) + is None + ) + + @pytest.mark.anyio + async def test_disconnect_connection_is_owner_scoped(self, repo): + connection = await repo.upsert_connection( + owner_user_id="alice", + provider="telegram", + external_account_id="42", + ) + + 
disconnected = await repo.disconnect_connection( + connection_id=connection["id"], + owner_user_id="bob", + ) + + assert disconnected is False + assert (await repo.list_connections("alice"))[0]["status"] == "connected" + + @pytest.mark.anyio + async def test_consume_oauth_state_deletes_expired_states(self, repo): + now = datetime.now(UTC) + await repo.create_oauth_state( + owner_user_id="alice", + provider="slack", + state="expired-state", + expires_at=now - timedelta(minutes=1), + ) + await repo.create_oauth_state( + owner_user_id="alice", + provider="slack", + state="active-state", + expires_at=now + timedelta(minutes=5), + ) + + consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now) + + assert consumed is None + async with repo.session_factory() as session: + states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all() + assert [state.state_hash for state in states] == [repo.hash_state("active-state")] + + @pytest.mark.anyio + async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo): + import anyio + + now = datetime.now(UTC) + await repo.create_oauth_state( + owner_user_id="alice", + provider="slack", + state="bind-once", + expires_at=now + timedelta(minutes=5), + ) + + results: list = [] + + async def consume(): + results.append(await repo.consume_oauth_state(provider="slack", state="bind-once", now=now)) + + async with anyio.create_task_group() as tg: + tg.start_soon(consume) + tg.start_soon(consume) + + consumed = [result for result in results if result is not None] + assert len(consumed) == 1 + assert consumed[0]["owner_user_id"] == "alice" + + @pytest.mark.anyio + async def test_upsert_connection_retries_as_update_when_concurrent_insert_wins(self, repo): + """A losing concurrent INSERT retries as an UPDATE instead of raising IntegrityError.""" + first = await repo.upsert_connection( + owner_user_id="alice", + provider="slack", + external_account_id="U-race", + 
workspace_id="T-race", + status="pending", + ) + + real_factory = repo.session_factory + + class _EmptyResult: + @staticmethod + def scalar_one_or_none(): + return None + + class MissFirstSelectSession: + """Make the initial identity SELECT miss, as if a concurrent writer inserted after it.""" + + def __init__(self, session): + self._session = session + self._missed = False + + def __getattr__(self, name): + return getattr(self._session, name) + + async def execute(self, *args, **kwargs): + result = await self._session.execute(*args, **kwargs) + if not self._missed: + self._missed = True + return _EmptyResult() + return result + + async def __aenter__(self): + await self._session.__aenter__() + return self + + async def __aexit__(self, *args): + return await self._session.__aexit__(*args) + + repo.session_factory = lambda: MissFirstSelectSession(real_factory()) + try: + second = await repo.upsert_connection( + owner_user_id="alice", + provider="slack", + external_account_id="U-race", + workspace_id="T-race", + status="connected", + ) + finally: + repo.session_factory = real_factory + + assert second["id"] == first["id"] + assert second["status"] == "connected" + connections = await repo.list_connections("alice") + assert len(connections) == 1 diff --git a/backend/tests/test_channel_connections_router.py b/backend/tests/test_channel_connections_router.py new file mode 100644 index 000000000..f4915fac8 --- /dev/null +++ b/backend/tests/test_channel_connections_router.py @@ -0,0 +1,963 @@ +"""Router tests for browser-connectable IM channels.""" + +from __future__ import annotations + +from tempfile import TemporaryDirectory +from types import SimpleNamespace +from unittest.mock import AsyncMock +from uuid import UUID + +import pytest +from _router_auth_helpers import make_authed_test_app +from fastapi.testclient import TestClient + +from app.channels.runtime_config_store import ChannelRuntimeConfigStore +from app.gateway.auth.models import User +from 
app.gateway.routers import channel_connections +from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config +from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + +@pytest.fixture(autouse=True) +def _stub_app_config(monkeypatch): + """Keep router tests independent from a developer-local config.yaml.""" + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "0") + set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}})) + yield + reset_app_config() + + +def _user() -> User: + return User( + id=UUID("11111111-2222-3333-4444-555555555555"), + email="alice@example.com", + password_hash="x", + system_role="admin", + ) + + +def _non_admin_user() -> User: + return User( + id=UUID("99999999-8888-7777-6666-555555555555"), + email="bob@example.com", + password_hash="x", + system_role="user", + ) + + +async def _make_repo(tmp_path): + from deerflow.persistence.channel_connections import ChannelConnectionRepository + from deerflow.persistence.engine import get_session_factory, init_engine + + await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'router.db'}", sqlite_dir=str(tmp_path)) + return ChannelConnectionRepository(get_session_factory()) + + +def _make_app( + config: ChannelConnectionsConfig, + repo, + channels_config: dict | None = None, + *, + runtime_config_store: ChannelRuntimeConfigStore | None = None, + set_channels_config_state: bool = True, +): + app = make_authed_test_app(user_factory=_user) + app.state.channel_connections_config = config + app.state.channel_connection_repo = repo + if set_channels_config_state: + app.state.channels_config = channels_config or {} + if runtime_config_store is None: + runtime_config_dir = TemporaryDirectory() + app.state.channel_runtime_config_tmpdir = runtime_config_dir + runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json") + app.state.channel_runtime_config_store = 
runtime_config_store + app.include_router(channel_connections.router) + return app + + +def _enabled_connections_config() -> ChannelConnectionsConfig: + return ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "telegram": {"enabled": True, "bot_username": "deerflow_bot"}, + "slack": {"enabled": True}, + "discord": {"enabled": True}, + "feishu": {"enabled": True}, + "dingtalk": {"enabled": True}, + "wechat": {"enabled": True}, + "wecom": {"enabled": True}, + } + ) + + +def _channels_config() -> dict: + return { + "telegram": {"enabled": True, "bot_token": "telegram-token"}, + "slack": {"enabled": True, "bot_token": "xoxb-operator", "app_token": "xapp-operator"}, + "discord": {"enabled": True, "bot_token": "discord-bot"}, + "feishu": {"enabled": True, "app_id": "feishu-app", "app_secret": "feishu-secret"}, + "dingtalk": {"enabled": True, "client_id": "dingtalk-client", "client_secret": "dingtalk-secret"}, + "wechat": {"enabled": True, "bot_token": "wechat-token"}, + "wecom": {"enabled": True, "bot_id": "wecom-bot", "bot_secret": "wecom-secret"}, + } + + +def test_get_providers_only_returns_enabled_channels_and_setup_fields(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + "discord": {"enabled": False}, + } + ) + app = _make_app(config, repo, {}) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + body = response.json() + assert body["enabled"] is True + assert [provider["provider"] for provider in body["providers"]] == ["slack"] + assert body["providers"][0]["configured"] is False + assert body["providers"][0]["connectable"] is False + assert body["providers"][0]["credential_fields"] == [ + { + "name": "bot_token", + "label": "Bot token", + "type": "password", + "required": True, + }, + { + "name": "app_token", + "label": "App token", + "type": 
"password", + "required": True, + }, + ] + + anyio.run(repo.close) + + +def test_get_providers_uses_existing_channels_config(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + body = response.json() + assert body["enabled"] is True + by_provider = {item["provider"]: item for item in body["providers"]} + assert set(by_provider) == {"telegram", "slack", "discord", "feishu", "dingtalk", "wechat", "wecom"} + assert by_provider["telegram"]["configured"] is True + assert by_provider["telegram"]["auth_mode"] == "deep_link" + assert by_provider["telegram"]["credential_values"] == { + "bot_token": "********", + "bot_username": "deerflow_bot", + } + assert by_provider["slack"]["configured"] is True + assert by_provider["slack"]["auth_mode"] == "binding_code" + assert by_provider["slack"]["connection_status"] == "not_connected" + assert by_provider["slack"]["credential_values"] == { + "bot_token": "********", + "app_token": "********", + } + assert by_provider["discord"]["configured"] is True + assert by_provider["discord"]["auth_mode"] == "binding_code" + assert by_provider["discord"]["credential_values"] == {"bot_token": "********"} + assert by_provider["feishu"]["configured"] is True + assert by_provider["feishu"]["auth_mode"] == "binding_code" + assert by_provider["feishu"]["connection_status"] == "not_connected" + assert by_provider["feishu"]["credential_values"] == { + "app_id": "feishu-app", + "app_secret": "********", + } + assert by_provider["dingtalk"]["configured"] is True + assert by_provider["dingtalk"]["auth_mode"] == "binding_code" + assert by_provider["dingtalk"]["credential_values"] == { + "client_id": "dingtalk-client", + "client_secret": "********", + } + assert by_provider["wechat"]["configured"] is True + assert 
by_provider["wechat"]["auth_mode"] == "binding_code" + assert by_provider["wechat"]["credential_values"] == {"bot_token": "********"} + assert by_provider["wecom"]["configured"] is True + assert by_provider["wecom"]["auth_mode"] == "binding_code" + assert by_provider["wecom"]["credential_values"] == { + "bot_id": "wecom-bot", + "bot_secret": "********", + } + + anyio.run(repo.close) + + +def test_get_providers_degrades_when_persistence_is_unavailable(monkeypatch): + monkeypatch.setattr(channel_connections, "get_session_factory", lambda: None) + app = _make_app(_enabled_connections_config(), None, _channels_config()) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["slack"]["configured"] is True + assert by_provider["slack"]["connectable"] is True + assert by_provider["slack"]["connection_status"] == "not_connected" + + +def test_get_providers_reports_connected_without_binding_in_auth_disabled_mode(tmp_path, monkeypatch): + import anyio + + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + monkeypatch.delenv("DEER_FLOW_ENV", raising=False) + monkeypatch.delenv("ENVIRONMENT", raising=False) + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + # Auth-disabled local mode routes channel messages to the default user, so + # a configured running channel is effectively connected without a binding. 
+ assert by_provider["slack"]["connection_status"] == "connected" + assert by_provider["feishu"]["connection_status"] == "connected" + + anyio.run(repo.close) + + +def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, {"telegram": {"enabled": True, "bot_token": "telegram-token"}}) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["telegram"]["configured"] is True + assert by_provider["slack"]["configured"] is False + assert by_provider["slack"]["connectable"] is False + assert "Slack credentials" in by_provider["slack"]["unavailable_reason"] + assert by_provider["discord"]["configured"] is False + assert "Discord credentials" in by_provider["discord"]["unavailable_reason"] + assert by_provider["feishu"]["configured"] is False + assert "Feishu credentials" in by_provider["feishu"]["unavailable_reason"] + assert by_provider["dingtalk"]["configured"] is False + assert "DingTalk credentials" in by_provider["dingtalk"]["unavailable_reason"] + assert by_provider["wechat"]["configured"] is False + assert "WeChat credentials" in by_provider["wechat"]["unavailable_reason"] + assert by_provider["wecom"]["configured"] is False + assert "WeCom credentials" in by_provider["wecom"]["unavailable_reason"] + + anyio.run(repo.close) + + +def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypatch): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + service = SimpleNamespace( + get_status=lambda: { + "service_running": True, + "channels": { + "feishu": { + "enabled": True, + "running": False, + } + }, + } + ) + 
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["feishu"]["configured"] is True + assert by_provider["feishu"]["connectable"] is False + assert by_provider["feishu"]["connection_status"] == "not_connected" + assert "configured but is not running" in by_provider["feishu"]["unavailable_reason"] + + anyio.run(repo.close) + + +def test_get_providers_restarts_configured_channel_when_service_can_reconcile(tmp_path, monkeypatch): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "feishu": {"enabled": True}, + } + ) + channels_config = { + "feishu": { + "enabled": True, + "app_id": "feishu-app", + "app_secret": "feishu-secret", + } + } + app = _make_app(config, repo, channels_config) + status = { + "service_running": True, + "channels": { + "feishu": { + "enabled": True, + "running": False, + } + }, + } + reconciled: list[tuple[str, dict]] = [] + + async def ensure_channel_ready(provider, runtime_config): + reconciled.append((provider, dict(runtime_config))) + status["channels"][provider]["running"] = True + return True + + service = SimpleNamespace( + get_status=lambda: status, + ensure_channel_ready=ensure_channel_ready, + ) + monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["feishu"]["configured"] is True + assert by_provider["feishu"]["connectable"] is True + assert by_provider["feishu"]["connection_status"] == "not_connected" + assert 
by_provider["feishu"]["unavailable_reason"] is None + assert reconciled == [("feishu", channels_config["feishu"])] + + anyio.run(repo.close) + + +def test_get_providers_uses_newest_connection_status_per_provider(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + + async def seed_connections(): + await repo.upsert_connection( + owner_user_id=str(_user().id), + provider="slack", + external_account_id="U-old", + workspace_id="T-old", + status="revoked", + ) + await anyio.sleep(0.01) + await repo.upsert_connection( + owner_user_id=str(_user().id), + provider="slack", + external_account_id="U-new", + workspace_id="T-new", + status="connected", + ) + + anyio.run(seed_connections) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["slack"]["connection_status"] == "connected" + + anyio.run(repo.close) + + +def test_get_connections_returns_current_user_connections_only(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + + async def seed_connections(): + await repo.upsert_connection( + owner_user_id=str(_user().id), + provider="telegram", + external_account_id="42", + external_account_name="Alice", + status="connected", + ) + await repo.upsert_connection( + owner_user_id="other-user", + provider="telegram", + external_account_id="99", + external_account_name="Bob", + status="connected", + ) + + anyio.run(seed_connections) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.get("/api/channels/connections") + + assert response.status_code == 200 + body = response.json() + assert len(body["connections"]) == 1 + assert body["connections"][0]["provider"] == "telegram" + assert 
body["connections"][0]["external_account_id"] == "42" + + anyio.run(repo.close) + + +def test_connect_telegram_returns_deep_link_and_persists_state(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.post("/api/channels/telegram/connect") + + assert response.status_code == 200 + body = response.json() + assert body["provider"] == "telegram" + assert body["mode"] == "deep_link" + assert body["url"].startswith("https://t.me/deerflow_bot?start=") + assert body["code"] + assert "/start" in body["instruction"] + + async def count_states(): + return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="telegram") + + assert anyio.run(count_states) == 1 + + anyio.run(repo.close) + + +def test_connect_slack_returns_binding_command_and_persists_state(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.post("/api/channels/slack/connect") + + assert response.status_code == 200 + body = response.json() + assert body["provider"] == "slack" + assert body["mode"] == "binding_code" + assert body["url"] is None + assert len(body["code"]) >= 22 + assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot." 
+ + async def count_states(): + return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack") + + assert anyio.run(count_states) == 1 + + anyio.run(repo.close) + + +def test_connect_discord_returns_binding_command_and_persists_state(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.post("/api/channels/discord/connect") + + assert response.status_code == 200 + body = response.json() + assert body["provider"] == "discord" + assert body["mode"] == "binding_code" + assert body["url"] is None + assert body["code"] + assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Discord bot." + + async def count_states(): + return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="discord") + + assert anyio.run(count_states) == 1 + + anyio.run(repo.close) + + +def test_connect_existing_binding_code_channels_return_command_and_persist_state(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + providers = ["feishu", "dingtalk", "wechat", "wecom"] + with TestClient(app) as client: + responses = {provider: client.post(f"/api/channels/{provider}/connect") for provider in providers} + + for provider, response in responses.items(): + expected_display_name = { + "feishu": "Feishu", + "dingtalk": "DingTalk", + "wechat": "WeChat", + "wecom": "WeCom", + }[provider] + assert response.status_code == 200 + body = response.json() + assert body["provider"] == provider + assert body["mode"] == "binding_code" + assert body["url"] is None + assert len(body["code"]) >= 22 + assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow {expected_display_name} bot." 
+ + async def count_states(provider=provider): + return await repo.count_oauth_states(owner_user_id=str(_user().id), provider=provider) + + assert anyio.run(count_states) == 1 + + anyio.run(repo.close) + + +def test_connect_unconfigured_runtime_channel_returns_400(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, {}) + + with TestClient(app) as client: + response = client.post("/api/channels/slack/connect") + + assert response.status_code == 400 + assert "Slack credentials" in response.json()["detail"] + + anyio.run(repo.close) + + +def test_configure_provider_runtime_credentials_enables_connect_without_file_edits(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + app = _make_app(config, repo, {}) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + connect_response = client.post("/api/channels/slack/connect") + + assert configure_response.status_code == 200 + configured = configure_response.json() + assert configured["provider"] == "slack" + assert configured["configured"] is True + assert configured["connectable"] is True + assert configured["connection_status"] == "not_connected" + assert app.state.channels_config["slack"] == { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + } + assert connect_response.status_code == 200 + assert connect_response.json()["provider"] == "slack" + + anyio.run(repo.close) + + +def test_runtime_config_endpoints_require_admin(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + app = make_authed_test_app(user_factory=_non_admin_user) + 
app.state.channel_connections_config = config + app.state.channel_connection_repo = repo + app.state.channels_config = {} + runtime_config_dir = TemporaryDirectory() + app.state.channel_runtime_config_tmpdir = runtime_config_dir + app.state.channel_runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json") + app.include_router(channel_connections.router) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + disconnect_response = client.delete("/api/channels/slack/runtime-config") + providers_response = client.get("/api/channels/providers") + + assert configure_response.status_code == 403 + assert "Admin privileges" in configure_response.json()["detail"] + assert disconnect_response.status_code == 403 + # Read-only provider listing stays available to regular users. + assert providers_response.status_code == 200 + + anyio.run(repo.close) + + +def test_configure_telegram_runtime_uses_new_bot_username_for_deep_link_without_mutating_config(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "telegram": {"enabled": True, "bot_username": "old_bot"}, + } + ) + app = _make_app(config, repo, {}) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/telegram/runtime-config", + json={"values": {"bot_token": "tg-token", "bot_username": "new_bot"}}, + ) + connect_response = client.post("/api/channels/telegram/connect") + + assert configure_response.status_code == 200 + assert configure_response.json()["credential_values"]["bot_username"] == "new_bot" + assert connect_response.status_code == 200 + assert connect_response.json()["url"].startswith("https://t.me/new_bot?start=") + # The original config object cached by get_app_config() must stay untouched. 
+ assert config.telegram.bot_username == "old_bot" + + anyio.run(repo.close) + + +def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + runtime_config_path = tmp_path / "channels" / "runtime-config.json" + first_app = _make_app( + config, + repo, + {}, + runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path), + ) + + with TestClient(first_app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + + assert configure_response.status_code == 200 + + restarted_app = _make_app( + config, + repo, + runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path), + set_channels_config_state=False, + ) + + with TestClient(restarted_app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["slack"]["configured"] is True + assert by_provider["slack"]["connectable"] is True + assert by_provider["slack"]["connection_status"] == "not_connected" + assert restarted_app.state.channels_config["slack"] == { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + } + + anyio.run(repo.close) + + +def test_configure_provider_runtime_credentials_preserves_masked_secrets(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "feishu": {"enabled": True}, + } + ) + runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + app = _make_app( + config, + repo, + { + "feishu": { + "enabled": True, + "app_id": "old-app-id", + "app_secret": "old-secret", + } + }, + 
runtime_config_store=runtime_config_store, + ) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/feishu/runtime-config", + json={ + "values": { + "app_id": "new-app-id", + "app_secret": "********", + } + }, + ) + providers_response = client.get("/api/channels/providers") + + assert configure_response.status_code == 200 + assert app.state.channels_config["feishu"] == { + "enabled": True, + "app_id": "new-app-id", + "app_secret": "old-secret", + } + assert runtime_config_store.get_provider_config("feishu") == { + "enabled": True, + "app_id": "new-app-id", + "app_secret": "old-secret", + } + by_provider = {item["provider"]: item for item in providers_response.json()["providers"]} + assert by_provider["feishu"]["credential_values"] == { + "app_id": "new-app-id", + "app_secret": "********", + } + + anyio.run(repo.close) + + +def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + disconnect_response = client.delete("/api/channels/slack/runtime-config") + providers_response = client.get("/api/channels/providers") + + assert configure_response.status_code == 200 + assert disconnect_response.status_code == 200 + disconnected = disconnect_response.json() + assert disconnected["provider"] == "slack" + assert disconnected["configured"] is False + assert disconnected["connectable"] is False + assert disconnected["connection_status"] == "not_connected" + assert 
runtime_config_store.get_provider_config("slack") == { + "enabled": False, + "_runtime_disabled": True, + } + + assert providers_response.status_code == 200 + by_provider = {item["provider"]: item for item in providers_response.json()["providers"]} + assert by_provider["slack"]["connection_status"] == "not_connected" + + anyio.run(repo.close) + + +def test_disconnect_provider_runtime_config_suppresses_file_config_and_stops_channel(tmp_path, monkeypatch): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "feishu": {"enabled": True}, + } + ) + set_app_config( + AppConfig.model_validate( + { + "sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}, + "channels": { + "feishu": { + "enabled": True, + "app_id": "file-app-id", + "app_secret": "file-secret", + } + }, + } + ) + ) + runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + runtime_config_store.set_provider_config( + "feishu", + { + "enabled": True, + "app_id": "runtime-app-id", + "app_secret": "runtime-secret", + }, + ) + service = SimpleNamespace( + configure_channel=AsyncMock(return_value=True), + remove_channel=AsyncMock(return_value=True), + ) + monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service) + app = _make_app( + config, + repo, + { + "feishu": { + "enabled": True, + "app_id": "runtime-app-id", + "app_secret": "runtime-secret", + } + }, + runtime_config_store=runtime_config_store, + ) + + with TestClient(app) as client: + disconnect_response = client.delete("/api/channels/feishu/runtime-config") + providers_response = client.get("/api/channels/providers") + + assert disconnect_response.status_code == 200 + disconnected = disconnect_response.json() + assert disconnected["provider"] == "feishu" + assert disconnected["configured"] is False + assert disconnected["connectable"] is False + assert disconnected["connection_status"] == 
"not_connected" + assert "feishu" not in app.state.channels_config + service.remove_channel.assert_awaited_once_with("feishu") + service.configure_channel.assert_not_awaited() + + assert providers_response.status_code == 200 + by_provider = {item["provider"]: item for item in providers_response.json()["providers"]} + assert by_provider["feishu"]["configured"] is False + assert by_provider["feishu"]["connection_status"] == "not_connected" + + anyio.run(repo.close) + + +def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + + async def seed_connection(): + await repo.upsert_connection( + owner_user_id=str(_user().id), + provider="slack", + external_account_id="U123", + status="connected", + ) + + anyio.run(seed_connection) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + disconnect_response = client.delete("/api/channels/slack/runtime-config") + + assert configure_response.status_code == 200 + assert disconnect_response.status_code == 200 + + async def get_connection_status(): + return (await repo.list_connections(str(_user().id)))[0]["status"] + + assert anyio.run(get_connection_status) == "revoked" + + anyio.run(repo.close) + + +def test_disconnect_connection_revokes_current_user_connection(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + + async def seed_connection(): + connection = await repo.upsert_connection( + owner_user_id=str(_user().id), + provider="telegram", + external_account_id="42", + status="connected", + ) + 
return connection["id"] + + connection_id = anyio.run(seed_connection) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.delete(f"/api/channels/connections/{connection_id}") + + assert response.status_code == 204 + + async def get_connection_status(): + return (await repo.list_connections(str(_user().id)))[0]["status"] + + assert anyio.run(get_connection_status) == "revoked" + + anyio.run(repo.close) + + +def test_disconnect_connection_is_current_user_scoped(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + + async def seed_connection(): + connection = await repo.upsert_connection( + owner_user_id="other-user", + provider="telegram", + external_account_id="42", + status="connected", + ) + return connection["id"] + + connection_id = anyio.run(seed_connection) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + response = client.delete(f"/api/channels/connections/{connection_id}") + + assert response.status_code == 404 + + async def get_connection_status(): + return (await repo.list_connections("other-user"))[0]["status"] + + assert anyio.run(get_connection_status) == "connected" + + anyio.run(repo.close) diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 40eecc529..febdc11e5 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -487,6 +487,7 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None): # threads.create() returns a Thread-like dict mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id}) + mock_client.threads.update = AsyncMock(return_value={"thread_id": thread_id}) # threads.get() returns thread info (succeeds by default) mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id}) @@ -504,6 +505,17 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", 
run_result=None): return mock_client +async def _make_channel_connection_repo(tmp_path: Path): + from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher + from deerflow.persistence.engine import get_session_factory, init_engine + + await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'channel-connections.db'}", sqlite_dir=str(tmp_path)) + return ChannelConnectionRepository( + get_session_factory(), + cipher=ChannelCredentialCipher.from_key("test-channel-key"), + ) + + def _make_stream_part(event: str, data): return SimpleNamespace(event=event, data=data) @@ -656,16 +668,34 @@ class TestChannelManager: await manager.start() - inbound = InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi") + inbound = InboundMessage( + channel_name="test", + chat_id="chat1", + user_id="user1", + text="hi", + topic_id="topic1", + thread_ts="msg1", + connection_id="conn1", + ) await bus.publish_inbound(inbound) await _wait_for(lambda: len(outbound_received) >= 1) await manager.stop() # Thread should be created through Gateway mock_client.threads.create.assert_called_once() + assert mock_client.threads.create.call_args.kwargs["metadata"] == { + "channel_source": { + "type": "im_channel", + "provider": "test", + "chat_id": "chat1", + "topic_id": "topic1", + "thread_ts": "msg1", + "connection_id": "conn1", + } + } # Thread ID should be stored - thread_id = store.get_thread_id("test", "chat1") + thread_id = store.get_thread_id("test", "chat1", topic_id="topic1") assert thread_id == "test-thread-123" # runs.wait should be called with the thread_id @@ -883,10 +913,12 @@ class TestChannelManager: _run(go()) - def test_clarification_follow_up_preserves_history(self): + def test_clarification_follow_up_preserves_history(self, monkeypatch): """Conversation should continue after ask_clarification instead of resetting history.""" from app.channels.manager import ChannelManager + 
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + async def go(): bus = MessageBus() store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") @@ -1954,10 +1986,12 @@ class TestChannelManager: _run(go()) - def test_same_topic_reuses_thread(self): + def test_same_topic_reuses_thread(self, monkeypatch): """Messages with the same topic_id should reuse the same DeerFlow thread.""" from app.channels.manager import ChannelManager + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + async def go(): bus = MessageBus() store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") @@ -1990,6 +2024,17 @@ class TestChannelManager: # threads.create should be called only ONCE (second message reuses the thread) mock_client.threads.create.assert_called_once() + mock_client.threads.update.assert_called_once_with( + "topic-thread-1", + metadata={ + "channel_source": { + "type": "im_channel", + "provider": "test", + "chat_id": "chat1", + "topic_id": "topic-root-123", + } + }, + ) # Both runs.wait calls should use the same thread_id assert mock_client.runs.wait.call_count == 2 @@ -2325,8 +2370,9 @@ class TestResolveRunParamsUserId: store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") return ChannelManager(bus=bus, store=store) - def test_safe_user_id_is_passed_through(self): + def test_safe_user_id_is_passed_through(self, monkeypatch): manager = self._manager() + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi") _, _, run_context = manager._resolve_run_params(msg, "thread-1") @@ -2334,10 +2380,78 @@ class TestResolveRunParamsUserId: assert run_context["user_id"] == "123456" assert run_context["channel_user_id"] == "123456" - def test_unsafe_user_id_is_normalized_but_raw_preserved(self): + def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch): + manager = self._manager() + 
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + msg = InboundMessage( + channel_name="slack", + chat_id="C123", + user_id="U-platform", + owner_user_id="deerflow-user-1", + connection_id="connection-1", + text="hi", + ) + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == "deerflow-user-1" + assert run_context["channel_user_id"] == "U-platform" + + def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch): + from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME + + manager = self._manager() + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == AUTH_DISABLED_USER_ID + assert run_context["channel_user_id"] == "U-platform" + + from app.channels.manager import _owner_headers + + headers = _owner_headers(msg) + assert headers is not None + assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID + + def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch): + from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID + + manager = self._manager() + monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") + msg = InboundMessage( + channel_name="slack", + chat_id="C123", + user_id="U-platform", + owner_user_id="real-user-from-old-binding", + text="hi", + ) + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == AUTH_DISABLED_USER_ID + assert run_context["channel_user_id"] == "U-platform" + + def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch): + from app.channels.manager import _owner_headers + + manager = self._manager() + 
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + assert run_context["user_id"] == "U-platform" + assert run_context["channel_user_id"] == "U-platform" + assert _owner_headers(msg) is None + + def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch): from deerflow.config.paths import make_safe_user_id manager = self._manager() + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) raw = "user@example.com" msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi") @@ -2347,9 +2461,32 @@ class TestResolveRunParamsUserId: assert run_context["user_id"] != raw assert run_context["channel_user_id"] == raw - @pytest.mark.parametrize("raw_user_id", ["", None]) - def test_empty_or_none_user_id_is_not_injected(self, raw_user_id): + def test_unsafe_user_id_migrates_unique_legacy_bucket(self, tmp_path, monkeypatch): + from deerflow.config.paths import Paths, make_safe_user_id + + paths = Paths(tmp_path) + legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b" + legacy_dir.mkdir(parents=True) + (legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8") + monkeypatch.setattr("deerflow.config.paths.get_paths", lambda: paths) + manager = self._manager() + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + raw = "user@example.com" + msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi") + + _, _, run_context = manager._resolve_run_params(msg, "thread-1") + + safe = make_safe_user_id(raw) + assert run_context["user_id"] == safe + assert paths.user_dir(safe).exists() + assert not legacy_dir.exists() + assert (paths.user_dir(safe) / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n' + + @pytest.mark.parametrize("raw_user_id", ["", None]) + def 
test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch): + manager = self._manager() + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi") _, _, run_context = manager._resolve_run_params(msg, "thread-1") @@ -2358,6 +2495,93 @@ class TestResolveRunParamsUserId: assert "channel_user_id" not in run_context +class TestChannelManagerConnectionRouting: + def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch): + from app.channels.manager import ChannelManager + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME + from deerflow.persistence.engine import close_engine + + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + + async def go(): + repo = await _make_channel_connection_repo(tmp_path) + alice = await repo.upsert_connection( + owner_user_id="alice", + provider="slack", + external_account_id="U-alice", + workspace_id="T1", + ) + bob = await repo.upsert_connection( + owner_user_id="bob", + provider="slack", + external_account_id="U-bob", + workspace_id="T1", + ) + + bus = MessageBus() + store = ChannelStore(path=tmp_path / "legacy-store.json") + manager = ChannelManager(bus=bus, store=store, connection_repo=repo) + mock_client = _make_mock_langgraph_client() + mock_client.threads.create = AsyncMock( + side_effect=[ + {"thread_id": "thread-alice"}, + {"thread_id": "thread-bob"}, + ] + ) + manager._client = mock_client + + await manager._handle_chat( + InboundMessage( + channel_name="slack", + chat_id="C-shared", + user_id="U-alice", + owner_user_id="alice", + connection_id=alice["id"], + text="hello", + thread_ts="1710000000.000100", + topic_id="1710000000.000100", + ) + ) + await manager._handle_chat( + InboundMessage( + channel_name="slack", + chat_id="C-shared", + user_id="U-bob", + owner_user_id="bob", + connection_id=bob["id"], + text="hello", + thread_ts="1710000000.000100", + 
topic_id="1710000000.000100", + ) + ) + + assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice" + assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob" + assert store.list_entries() == [] + + first_context = mock_client.runs.wait.call_args_list[0].kwargs["context"] + second_context = mock_client.runs.wait.call_args_list[1].kwargs["context"] + assert first_context["user_id"] == "alice" + assert first_context["channel_user_id"] == "U-alice" + assert second_context["user_id"] == "bob" + assert second_context["channel_user_id"] == "U-bob" + + first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"] + second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"] + assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice" + assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob" + + first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"] + second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"] + assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice" + assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob" + + try: + _run(go()) + finally: + _run(close_engine()) + + # --------------------------------------------------------------------------- # ChannelService tests # --------------------------------------------------------------------------- @@ -3108,6 +3332,38 @@ class TestChannelService: _run(go()) + def test_concurrent_ensure_channel_ready_starts_channel_once(self): + from app.channels.service import ChannelService + + async def go(): + service = ChannelService( + channels_config={ + "telegram": {"enabled": True, "bot_token": "tg-token"}, + } + ) + await service.manager.start() + service._running = True + start_calls = [] + + async def fake_start_channel(name, config): + start_calls.append(name) + await asyncio.sleep(0.01) + 
service._channels[name] = SimpleNamespace(is_running=True, stop=AsyncMock()) + return True + + service._start_channel = fake_start_channel + + results = await asyncio.gather( + service.ensure_channel_ready("telegram"), + service.ensure_channel_ready("telegram"), + ) + + assert results == [True, True] + assert start_calls == ["telegram"] + await service.stop() + + _run(go()) + def test_session_config_is_forwarded_to_manager(self): from app.channels.service import ChannelService @@ -3175,6 +3431,226 @@ class TestChannelService: assert service._config == {"telegram": {"enabled": False}} + def test_from_app_config_does_not_create_runtime_channels_from_channel_connections( + self, + monkeypatch, + tmp_path, + ): + from app.channels.service import ChannelService + from deerflow.config import paths as paths_module + from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + app_config = SimpleNamespace( + model_extra={}, + channel_connections=ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "telegram": {"enabled": True, "bot_username": "deerflow_bot"}, + "slack": {"enabled": True}, + "discord": {"enabled": True}, + } + ), + ) + + service = ChannelService.from_app_config(app_config) + + assert service._config == {} + + def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled( + self, + monkeypatch, + tmp_path, + ): + from app.channels.runtime_config_store import ChannelRuntimeConfigStore + from app.channels.service import ChannelService + from deerflow.config import paths as paths_module + from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + ChannelRuntimeConfigStore().set_provider_config( + "slack", + { + "enabled": True, + "bot_token": "xoxb-ui", + 
"app_token": "xapp-ui", + }, + ) + app_config = SimpleNamespace( + model_extra={ + "channels": { + "telegram": {"enabled": True, "bot_token": "telegram-token"}, + "slack": {"enabled": True, "bot_token": "xoxb", "app_token": "xapp"}, + "discord": {"enabled": True, "bot_token": "discord-bot-token"}, + } + }, + channel_connections=ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "telegram": {"enabled": True, "bot_username": "deerflow_bot"}, + "slack": {"enabled": True}, + "discord": {"enabled": True}, + } + ), + ) + + service = ChannelService.from_app_config(app_config) + + assert service._config["telegram"]["bot_token"] == "telegram-token" + assert service._config["slack"]["app_token"] == "xapp" + assert service._config["discord"]["bot_token"] == "discord-bot-token" + + def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path): + from app.channels.runtime_config_store import ChannelRuntimeConfigStore + from app.channels.service import ChannelService + from deerflow.config import paths as paths_module + from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + ChannelRuntimeConfigStore().set_provider_config( + "slack", + { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + }, + ) + app_config = SimpleNamespace( + model_extra={}, + channel_connections=ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ), + ) + + service = ChannelService.from_app_config(app_config) + + assert service._config["slack"] == { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + } + + def test_from_app_config_runtime_disconnect_suppresses_file_channel_config(self, monkeypatch, tmp_path): + from app.channels.runtime_config_store import ChannelRuntimeConfigStore + from app.channels.service import ChannelService + from 
deerflow.config import paths as paths_module + from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + ChannelRuntimeConfigStore().set_provider_config( + "feishu", + { + "enabled": False, + "_runtime_disabled": True, + }, + ) + app_config = SimpleNamespace( + model_extra={ + "channels": { + "feishu": { + "enabled": True, + "app_id": "file-app-id", + "app_secret": "file-secret", + } + } + }, + channel_connections=ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "feishu": {"enabled": True}, + } + ), + ) + + service = ChannelService.from_app_config(app_config) + + assert "feishu" not in service._config + + def test_start_retries_configured_channel_until_ready(self, monkeypatch): + from app.channels.service import ChannelService + + class FlakyReadyChannel(Channel): + starts = 0 + + def __init__(self, bus, config): + super().__init__(name="slack", bus=bus, config=config) + + async def start(self): + type(self).starts += 1 + self._running = type(self).starts >= 2 + + async def stop(self): + self._running = False + + async def send(self, msg): + return None + + monkeypatch.setattr( + "deerflow.reflection.resolve_class", + lambda import_path, base_class=None: FlakyReadyChannel, + ) + + async def go(): + service = ChannelService( + channels_config={ + "slack": { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + }, + } + ) + + try: + await service.start() + + assert FlakyReadyChannel.starts == 2 + assert service.get_status()["channels"]["slack"]["running"] is True + finally: + await service.stop() + + _run(go()) + + def test_connection_repo_is_forwarded_to_manager(self): + from app.channels.service import ChannelService + + repo = object() + service = ChannelService(channels_config={}, connection_repo=repo) + + assert service.manager._connection_repo is repo + + def 
test_remove_channel_stops_running_channel_and_forgets_config(self): + from app.channels.service import ChannelService + + async def go(): + service = ChannelService( + channels_config={ + "slack": { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + }, + } + ) + channel = AsyncMock() + service._channels["slack"] = channel + service._running = True + + assert await service.remove_channel("slack") is True + + channel.stop.assert_awaited_once() + assert "slack" not in service._channels + assert "slack" not in service._config + + _run(go()) + def test_disabled_channel_with_string_creds_emits_warning(self, caplog): """Warning is emitted when a channel has string credentials but enabled=false.""" import logging @@ -3192,7 +3668,8 @@ class TestChannelService: await service.stop() _run(go()) - assert any("wecom" in r.message and r.levelno == logging.WARNING for r in caplog.records) + assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records) + assert all("wecom" not in r.message for r in caplog.records) def test_disabled_channel_with_int_creds_emits_warning(self, caplog): """Warning is emitted even when YAML-parsed integer credentials are present.""" @@ -3212,7 +3689,8 @@ class TestChannelService: await service.stop() _run(go()) - assert any("telegram" in r.message and r.levelno == logging.WARNING for r in caplog.records) + assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records) + assert all("telegram" not in r.message for r in caplog.records) def test_disabled_channel_without_creds_emits_info(self, caplog): """Only an info log (no warning) is emitted when a channel is disabled with no credentials.""" @@ -3267,6 +3745,83 @@ class TestChannelService: assert started_configs["feishu"]["app_secret"] == "new_secret" assert service._config["feishu"]["app_id"] == "new_id" + def 
test_configure_channel_keeps_explicit_config_over_stale_file_entry(self, monkeypatch): + """UI-entered runtime credentials must not be clobbered by a config.yaml reload. + + configure_channel() receives the authoritative config (e.g. from the + browser Connect/Modify dialog, never written to config.yaml), so its + restart must skip the file reload that restart_channel() performs for + operator-triggered restarts. + """ + from app.channels.service import ChannelService + + stale_file_config = {"feishu": {"enabled": True, "app_id": "file_id", "app_secret": "file_secret"}} + + def mock_get_app_config(): + return SimpleNamespace(model_extra={"channels": stale_file_config}) + + monkeypatch.setattr("deerflow.config.app_config.get_app_config", mock_get_app_config) + + service = ChannelService(channels_config={}) + service._running = True + + started_configs = {} + + async def mock_start_channel(name, config): + started_configs[name] = config + return True + + service._start_channel = mock_start_channel + + async def go(): + await service.configure_channel("feishu", {"enabled": True, "app_id": "ui_id", "app_secret": "ui_secret"}) + + _run(go()) + + assert started_configs["feishu"]["app_id"] == "ui_id" + assert started_configs["feishu"]["app_secret"] == "ui_secret" + assert service._config["feishu"]["app_id"] == "ui_id" + + def test_restart_channel_reload_applies_runtime_store_overlay(self, monkeypatch, tmp_path): + """An operator-triggered restart keeps UI runtime-store credentials for + channels that have no config.yaml entry.""" + from app.channels.runtime_config_store import ChannelRuntimeConfigStore + from app.channels.service import ChannelService + from deerflow.config import paths as paths_module + from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + ChannelRuntimeConfigStore().set_provider_config( + "telegram", + {"enabled": True, 
"bot_token": "store-token"}, + ) + + def mock_get_app_config(): + return SimpleNamespace( + model_extra={"channels": {}}, + channel_connections=ChannelConnectionsConfig.model_validate({"enabled": True, "telegram": {"enabled": True, "bot_username": "deerflow_bot"}}), + ) + + monkeypatch.setattr("deerflow.config.app_config.get_app_config", mock_get_app_config) + + service = ChannelService(channels_config={}) + + started_configs = {} + + async def mock_start_channel(name, config): + started_configs[name] = config + return True + + service._start_channel = mock_start_channel + + async def go(): + await service.restart_channel("telegram") + + _run(go()) + + assert started_configs["telegram"]["bot_token"] == "store-token" + def test_restart_channel_falls_back_to_cached_config_on_error(self, monkeypatch): """When get_app_config() fails, restart_channel uses cached config.""" from app.channels.service import ChannelService diff --git a/backend/tests/test_csrf_middleware.py b/backend/tests/test_csrf_middleware.py index 28a65c8d7..94dd8db38 100644 --- a/backend/tests/test_csrf_middleware.py +++ b/backend/tests/test_csrf_middleware.py @@ -233,3 +233,15 @@ def test_non_auth_mutation_rejects_mismatched_double_submit_token(): assert response.status_code == 403 assert response.json()["detail"] == "CSRF token mismatch." + + +def test_channel_posts_require_double_submit_csrf(): + client = TestClient(_make_app(), base_url="https://deerflow.example") + + response = client.post( + "/api/channels/slack/connect", + headers={"Origin": "https://deerflow.example"}, + ) + + assert response.status_code == 403 + assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header." 
diff --git a/backend/tests/test_discord_channel_connections.py b/backend/tests/test_discord_channel_connections.py new file mode 100644 index 000000000..7dc7a7ce1 --- /dev/null +++ b/backend/tests/test_discord_channel_connections.py @@ -0,0 +1,88 @@ +"""Discord connection routing tests.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.channels.discord import DiscordChannel +from app.channels.message_bus import InboundMessage, MessageBus + + +@pytest.fixture +async def repo(tmp_path): + from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + + await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'discord.db'}", sqlite_dir=str(tmp_path)) + try: + yield ChannelConnectionRepository( + get_session_factory(), + cipher=ChannelCredentialCipher.from_key("discord-secret"), + ) + finally: + await close_engine() + + +@pytest.mark.anyio +async def test_discord_inbound_attaches_owner_identity_from_user_level_connection(repo): + connection = await repo.upsert_connection( + owner_user_id="alice", + provider="discord", + external_account_id="987", + external_account_name="Alice", + status="connected", + ) + channel = DiscordChannel( + bus=MessageBus(), + config={"bot_token": "discord-bot", "connection_repo": repo}, + ) + inbound = InboundMessage( + channel_name="discord", + chat_id="C123", + user_id="987", + text="hello", + ) + + attached = await channel._attach_connection_identity(inbound, guild_id="G123") + + assert attached.connection_id == connection["id"] + assert attached.owner_user_id == "alice" + assert attached.workspace_id is None + + +@pytest.mark.anyio +async def test_discord_connect_command_binds_gateway_identity(repo): + state = "discord-bind-code" + await repo.create_oauth_state( + 
owner_user_id="deerflow-user-1", + provider="discord", + state=state, + expires_at=datetime.now(UTC) + timedelta(minutes=5), + ) + channel = DiscordChannel( + bus=MessageBus(), + config={"bot_token": "discord-bot", "connection_repo": repo}, + ) + message = MagicMock() + message.author.id = 987 + message.author.display_name = "Alice" + message.guild.id = 123 + message.guild.name = "Deer Guild" + message.channel.id = 456 + message.channel.send = AsyncMock() + + handled = await channel._bind_connection_from_connect_code(message, state) + + connections = await repo.list_connections("deerflow-user-1") + assert handled is True + assert len(connections) == 1 + assert connections[0]["provider"] == "discord" + assert connections[0]["external_account_id"] == "987" + assert connections[0]["external_account_name"] == "Alice" + assert connections[0]["workspace_id"] == "123" + assert connections[0]["workspace_name"] == "Deer Guild" + assert connections[0]["metadata"]["channel_id"] == "456" + message.channel.send.assert_awaited_once() diff --git a/backend/tests/test_feishu_parser.py b/backend/tests/test_feishu_parser.py index 5ecfb9e0b..34afb107d 100644 --- a/backend/tests/test_feishu_parser.py +++ b/backend/tests/test_feishu_parser.py @@ -73,6 +73,31 @@ def test_feishu_on_message_plain_text(): assert mock_make_inbound.call_args[1]["text"] == "Hello world" +def test_feishu_is_not_running_when_ws_thread_exits(): + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + channel._running = True + channel._thread = MagicMock() + channel._thread.is_alive.return_value = False + + assert channel.is_running is False + + +def test_feishu_event_handler_ignores_non_content_message_events(): + import lark_oapi as lark + + bus = MessageBus() + channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"}) + + event_handler = channel._build_event_handler(lark) + + assert "p2.im.message.receive_v1" in event_handler._processorMap + assert 
"p2.im.message.message_read_v1" in event_handler._processorMap + assert "p2.im.message.reaction.created_v1" in event_handler._processorMap + assert "p2.im.message.reaction.deleted_v1" in event_handler._processorMap + assert "p2.im.message.recalled_v1" in event_handler._processorMap + + def test_feishu_on_message_rich_text(): bus = MessageBus() config = {"app_id": "test", "app_secret": "test"} diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py index d62ed9371..41e59ed3b 100644 --- a/backend/tests/test_gateway_services.py +++ b/backend/tests/test_gateway_services.py @@ -4,6 +4,18 @@ from __future__ import annotations import json +import pytest + +from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config + + +@pytest.fixture +def _stub_app_config(): + """Keep run-context tests independent from a developer-local config.yaml.""" + set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}})) + yield + reset_app_config() + def test_format_sse_basic(): from app.gateway.services import format_sse @@ -36,6 +48,12 @@ def test_format_sse_no_event_id(): assert "id:" not in frame +def test_sanitize_log_param_strips_control_characters(): + from app.gateway.utils import sanitize_log_param + + assert sanitize_log_param("thread\nid\rwith\x00controls") == "threadidwithcontrols" + + def test_normalize_stream_modes_none(): from app.gateway.services import normalize_stream_modes @@ -474,6 +492,83 @@ def test_inject_authenticated_user_context_skips_internal_role(): assert config["context"]["user_id"] == "channel-user-7" +def test_start_run_uses_internal_owner_header_for_persistence(_stub_app_config): + import asyncio + from types import SimpleNamespace + from unittest.mock import patch + + from langgraph.checkpoint.memory import InMemorySaver + from langgraph.store.memory import InMemoryStore + + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, 
INTERNAL_SYSTEM_ROLE + from app.gateway.services import start_run + from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore + from deerflow.runtime import RunManager + from deerflow.runtime.runs.store.memory import MemoryRunStore + from deerflow.runtime.user_context import get_effective_user_id + + async def _scenario(): + run_store = MemoryRunStore() + thread_store = MemoryThreadMetaStore(InMemoryStore()) + await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True}) + run_manager = RunManager(store=run_store) + state = SimpleNamespace( + stream_bridge=SimpleNamespace(), + run_manager=run_manager, + checkpointer=InMemorySaver(), + store=InMemoryStore(), + run_event_store=SimpleNamespace(), + run_events_config=None, + thread_store=thread_store, + ) + request = SimpleNamespace( + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"}, + state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)), + app=SimpleNamespace(state=state), + ) + body = SimpleNamespace( + assistant_id="lead_agent", + input={"messages": [{"role": "human", "content": "hi"}]}, + metadata={}, + config=None, + context=None, + on_disconnect="cancel", + multitask_strategy="reject", + stream_mode=None, + stream_subgraphs=False, + interrupt_before=None, + interrupt_after=None, + ) + task_context: dict[str, str] = {} + + async def fake_run_agent(*args, **kwargs): + task_context["user_id"] = get_effective_user_id() + + with ( + patch("app.gateway.services.resolve_agent_factory", return_value=object()), + patch("app.gateway.services.run_agent", side_effect=fake_run_agent), + ): + record = await start_run(body, "channel-thread", request) + await record.task + + owner_run = await run_store.get(record.run_id, user_id="owner-1") + default_run = await run_store.get(record.run_id, user_id="default") + owner_thread = await thread_store.get("channel-thread", user_id="owner-1") + default_thread = await 
thread_store.get("channel-thread", user_id="default") + return owner_run, default_run, owner_thread, default_thread, task_context + + owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario()) + + assert owner_run is not None + assert owner_run["user_id"] == "owner-1" + assert default_run is None + assert owner_thread is not None + assert owner_thread["user_id"] == "owner-1" + assert owner_thread["metadata"] == {"legacy": True} + assert default_thread is None + assert task_context["user_id"] == "owner-1" + + # --------------------------------------------------------------------------- # build_run_config — context / configurable precedence (LangGraph >= 0.6.0) # --------------------------------------------------------------------------- diff --git a/backend/tests/test_internal_auth.py b/backend/tests/test_internal_auth.py index 7e56e1dd0..478b00d83 100644 --- a/backend/tests/test_internal_auth.py +++ b/backend/tests/test_internal_auth.py @@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch): assert reloaded.is_valid_internal_auth_token(token) is True finally: importlib.reload(reloaded) + + +def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch): + import app.gateway.internal_auth as internal_auth + + monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token") + reloaded = importlib.reload(internal_auth) + try: + headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1") + + assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token" + assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1" + finally: + monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False) + importlib.reload(reloaded) diff --git a/backend/tests/test_paths_user_isolation.py b/backend/tests/test_paths_user_isolation.py index 692c526ed..5f91e32e1 100644 --- a/backend/tests/test_paths_user_isolation.py +++ b/backend/tests/test_paths_user_isolation.py @@ -44,6 
+44,7 @@ class TestMakeSafeUserId: # Sanitized prefix plus a stable digest of the original. assert result.startswith("user-example-com-") assert len(result.rsplit("-", 1)[1]) == 16 + assert result == "user-example-com-b4c9a289323b21a0" assert make_safe_user_id("user@example.com") == result def test_sanitized_id_passes_validation(self, paths: Paths): @@ -69,6 +70,40 @@ class TestUserDir: def test_user_dir(self, paths: Paths): assert paths.user_dir("alice") == paths.base_dir / "users" / "alice" + def test_prepare_user_dir_migrates_unique_legacy_unsafe_bucket(self, paths: Paths): + from deerflow.config.paths import make_safe_user_id + + raw = "user@example.com" + safe = make_safe_user_id(raw) + legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b" + legacy_dir.mkdir(parents=True) + (legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8") + + assert paths.prepare_user_dir_for_raw_id(raw) == safe + + current_dir = paths.user_dir(safe) + assert current_dir.exists() + assert not legacy_dir.exists() + assert (current_dir / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n' + + def test_prepare_user_dir_never_migrates_another_users_bucket(self, paths: Paths): + """A different raw ID with the same sanitized prefix has a different legacy digest.""" + import hashlib + + from deerflow.config.paths import make_safe_user_id + + users_dir = paths.base_dir / "users" + other_legacy = users_dir / f"a-b-{hashlib.sha1(b'a/b').hexdigest()[:16]}" + other_legacy.mkdir(parents=True) + arbitrary_16_hex = users_dir / "a-b-1111111111111111" + arbitrary_16_hex.mkdir(parents=True) + + assert paths.prepare_user_dir_for_raw_id("a.b") == make_safe_user_id("a.b") + + assert not paths.user_dir(make_safe_user_id("a.b")).exists() + assert other_legacy.exists() + assert arbitrary_16_hex.exists() + class TestUserMemoryFile: def test_user_memory_file(self, paths: Paths): diff --git a/backend/tests/test_reload_boundary.py 
b/backend/tests/test_reload_boundary.py index 5610ccafb..daa33b0cc 100644 --- a/backend/tests/test_reload_boundary.py +++ b/backend/tests/test_reload_boundary.py @@ -90,6 +90,7 @@ def test_appconfig_descriptions_retain_original_field_documentation(): "run_events": "memory for dev", "checkpointer": "state-persistence checkpointer", "stream_bridge": "Stream bridge", + "channel_connections": "IM channel connection", } for field_name, expected_substring in descriptions.items(): description = AppConfig.model_fields[field_name].description or "" diff --git a/backend/tests/test_setup_wizard.py b/backend/tests/test_setup_wizard.py index 9eecb2eae..5f8be4ae0 100644 --- a/backend/tests/test_setup_wizard.py +++ b/backend/tests/test_setup_wizard.py @@ -7,7 +7,9 @@ Run from repo root: from __future__ import annotations import yaml +from wizard import ui as wizard_ui from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider +from wizard.steps import channels as channels_step from wizard.steps import llm as llm_step from wizard.steps import search as search_step from wizard.writer import ( @@ -327,6 +329,44 @@ class TestBuildMinimalConfig: assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled" assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled" + def test_can_enable_selected_channel_connections(self): + content = build_minimal_config( + provider_use="langchain_openai:ChatOpenAI", + model_name="gpt-4o", + display_name="OpenAI", + api_key_field="api_key", + env_var="OPENAI_API_KEY", + channel_connection_providers=["feishu", "slack"], + ) + + data = yaml.safe_load(content) + channel_connections = data["channel_connections"] + + assert channel_connections["enabled"] is True + assert channel_connections["feishu"]["enabled"] is True + assert channel_connections["slack"]["enabled"] is True + assert channel_connections["telegram"]["enabled"] is False + assert 
channel_connections["discord"]["enabled"] is False + assert channel_connections["dingtalk"]["enabled"] is False + assert channel_connections["wechat"]["enabled"] is False + assert channel_connections["wecom"]["enabled"] is False + + def test_channel_connections_disabled_when_no_channels_selected(self): + content = build_minimal_config( + provider_use="langchain_openai:ChatOpenAI", + model_name="gpt-4o", + display_name="OpenAI", + api_key_field="api_key", + env_var="OPENAI_API_KEY", + channel_connection_providers=[], + ) + + data = yaml.safe_load(content) + channel_connections = data["channel_connections"] + + assert channel_connections["enabled"] is False + assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled") + class TestLLMStep: def test_model_selection_defaults_to_provider_default_model(self, monkeypatch): @@ -384,6 +424,41 @@ class TestLLMStep: assert result.base_url == "https://gateway.example/v1" +class TestChannelsStep: + def test_returns_selected_channel_keys(self, monkeypatch): + monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6]) + + result = channels_step.run_channels_step() + + assert result.enabled_providers == ["telegram", "feishu", "wecom"] + + def test_empty_selection_disables_channel_connections(self, monkeypatch): + monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: []) + + result = channels_step.run_channels_step() + + 
assert result.enabled_providers == [] + + +class TestWizardUi: + def test_multi_choice_blank_requires_input_without_default(self, monkeypatch): + answers = iter(["", "2"]) + monkeypatch.setattr("builtins.input", lambda _prompt: next(answers)) + + assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1] + + def test_multi_choice_blank_accepts_empty_default(self, monkeypatch): + monkeypatch.setattr("builtins.input", lambda _prompt: "") + + assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == [] + + # --------------------------------------------------------------------------- # writer.py — env file helpers # --------------------------------------------------------------------------- diff --git a/backend/tests/test_slack_channel_connections.py b/backend/tests/test_slack_channel_connections.py new file mode 100644 index 000000000..5b718bb3c --- /dev/null +++ b/backend/tests/test_slack_channel_connections.py @@ -0,0 +1,154 @@ +"""Slack connection tests for user-owned channel bindings.""" + +from __future__ import annotations + +import sys +from datetime import UTC, datetime, timedelta +from types import ModuleType +from unittest.mock import AsyncMock, MagicMock + +from app.channels.message_bus import MessageBus, OutboundMessage + + +async def _make_repo(tmp_path): + from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher + from deerflow.persistence.engine import get_session_factory, init_engine + + await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'slack.db'}", sqlite_dir=str(tmp_path)) + return ChannelConnectionRepository( + get_session_factory(), + cipher=ChannelCredentialCipher.from_key("slack-secret"), + ) + + +def test_slack_connect_command_binds_socket_mode_identity(tmp_path): + import anyio + + from app.channels.slack import SlackChannel + + async def go(): + repo = await _make_repo(tmp_path) + state = "slack-bind-code" + await 
repo.create_oauth_state( + owner_user_id="deerflow-user-1", + provider="slack", + state=state, + expires_at=datetime.now(UTC) + timedelta(minutes=5), + ) + channel = SlackChannel( + bus=MessageBus(), + config={"bot_token": "xoxb-operator", "app_token": "xapp-operator", "connection_repo": repo}, + ) + channel._web_client = MagicMock() + + handled = await channel._bind_connection_from_connect_code( + event={ + "user": "U123", + "channel": "C123", + "ts": "1710000000.000100", + }, + team_id="T123", + code=state, + ) + + connections = await repo.list_connections("deerflow-user-1") + assert handled is True + assert len(connections) == 1 + assert connections[0]["provider"] == "slack" + assert connections[0]["external_account_id"] == "U123" + assert connections[0]["workspace_id"] == "T123" + assert connections[0]["metadata"]["channel_id"] == "C123" + channel._web_client.chat_postMessage.assert_called_once() + await repo.close() + + anyio.run(go) + + +def test_slack_send_uses_connection_bot_token_when_connection_id_is_present(): + import anyio + + from app.channels.slack import SlackChannel + + async def go(): + repo = AsyncMock() + repo.get_credentials.return_value = {"access_token": "xoxb-connection-token"} + web_client = MagicMock() + web_client_factory = MagicMock(return_value=web_client) + channel = SlackChannel( + bus=MessageBus(), + config={ + "connection_repo": repo, + "web_client_factory": web_client_factory, + }, + ) + + msg = OutboundMessage( + channel_name="slack", + chat_id="C123", + thread_id="thread-1", + text="hello", + connection_id="connection-1", + ) + await channel.send(msg) + + repo.get_credentials.assert_awaited_once_with("connection-1") + web_client_factory.assert_called_once_with(token="xoxb-connection-token") + web_client.chat_postMessage.assert_called_once() + + anyio.run(go) + + +def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch): + import anyio + + from app.channels.slack import SlackChannel + + class FakeWebClient: + 
def __init__(self, token: str) -> None: + self.token = token + self.messages: list[dict] = [] + + def auth_test(self): + return {"user_id": "B-http"} + + def chat_postMessage(self, **kwargs): + self.messages.append(kwargs) + + slack_sdk = ModuleType("slack_sdk") + slack_sdk.WebClient = FakeWebClient + socket_mode = ModuleType("slack_sdk.socket_mode") + socket_mode.SocketModeClient = object + response = ModuleType("slack_sdk.socket_mode.response") + response.SocketModeResponse = object + monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk) + monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode) + monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response) + + async def go(): + channel = SlackChannel( + bus=MessageBus(), + config={ + "bot_token": "xoxb-operator", + "event_delivery": "http", + "connection_repo": MagicMock(), + }, + ) + + await channel.start() + assert channel._running is True + assert channel._web_client is not None + assert channel._web_client.token == "xoxb-operator" + assert channel._bot_user_id == "B-http" + + await channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100") + + assert channel._web_client.messages == [ + { + "channel": "C123", + "text": "Slack connected to DeerFlow.", + "thread_ts": "1710000000.000100", + } + ] + await channel.stop() + + anyio.run(go) diff --git a/backend/tests/test_stateless_runs_owner_isolation.py b/backend/tests/test_stateless_runs_owner_isolation.py index 60a20d17c..6d6521238 100644 --- a/backend/tests/test_stateless_runs_owner_isolation.py +++ b/backend/tests/test_stateless_runs_owner_isolation.py @@ -164,10 +164,42 @@ def test_stream_shared_thread_passes_owner_check(): create_or_reject.assert_awaited() -def test_stream_internal_role_bypasses_owner_check(): - """IM channels run with the internal system role on behalf of platform - users whose threads they do not own — the owner check must not break them.""" +def 
test_stream_internal_role_scoped_by_owner_header(): + """IM channels run with the internal system role on behalf of the + connection owner named in X-DeerFlow-Owner-User-Id — the owner check is + scoped to that owner rather than bypassed.""" + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME + with _client(INTERNAL_USER) as (client, create_or_reject): - response = client.post("/api/runs/stream", json=_body(THREAD_A)) + response = client.post( + "/api/runs/stream", + json=_body(THREAD_A), + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: str(USER_A.id)}, + ) assert response.status_code == 409 create_or_reject.assert_awaited() + + +def test_stream_internal_role_with_foreign_owner_header_returns_404(): + """The internal token alone must not grant access to another user's thread.""" + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME + + with _client(INTERNAL_USER) as (client, create_or_reject): + response = client.post( + "/api/runs/stream", + json=_body(THREAD_A), + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: str(USER_B.id)}, + ) + assert response.status_code == 404 + create_or_reject.assert_not_awaited() + + +def test_stream_internal_role_without_owner_header_is_scoped_to_internal_user(): + """Without an owner header internal callers keep access to their own and + shared/untracked threads, but not to user-owned threads.""" + with _client(INTERNAL_USER) as (client, create_or_reject): + denied = client.post("/api/runs/stream", json=_body(THREAD_A)) + allowed = client.post("/api/runs/stream", json=_body(THREAD_SHARED)) + assert denied.status_code == 404 + assert allowed.status_code == 409 + create_or_reject.assert_awaited() diff --git a/backend/tests/test_telegram_channel_connections.py b/backend/tests/test_telegram_channel_connections.py new file mode 100644 index 000000000..b304e2abb --- /dev/null +++ b/backend/tests/test_telegram_channel_connections.py @@ -0,0 +1,100 @@ +"""Tests for Telegram deep-link channel 
connections.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from app.channels.message_bus import MessageBus +from app.channels.telegram import TelegramChannel + + +@pytest.fixture +async def repo(tmp_path: Path): + from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + + await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'telegram.db'}", sqlite_dir=str(tmp_path)) + try: + yield ChannelConnectionRepository( + get_session_factory(), + cipher=ChannelCredentialCipher.from_key("telegram-secret"), + ) + finally: + await close_engine() + + +def _telegram_update(*, text: str = "/start", user_id: int = 42, chat_id: int = 100, chat_type: str = "private"): + update = MagicMock() + update.effective_user.id = user_id + update.effective_user.username = "alice" + update.effective_user.full_name = "Alice Example" + update.effective_chat.id = chat_id + update.effective_chat.type = chat_type + update.message.text = text + update.message.message_id = 55 + update.message.reply_to_message = None + update.message.reply_text = AsyncMock() + return update + + +@pytest.mark.anyio +async def test_start_with_deep_link_state_binds_telegram_chat(repo): + state = "telegram-bind-state" + await repo.create_oauth_state( + owner_user_id="deerflow-user-1", + provider="telegram", + state=state, + expires_at=datetime.now(UTC) + timedelta(minutes=5), + ) + channel = TelegramChannel( + bus=MessageBus(), + config={"bot_token": "test-token", "connection_repo": repo}, + ) + update = _telegram_update(text=f"/start {state}") + context = MagicMock() + context.args = [state] + + await channel._cmd_start(update, context) + + connections = await repo.list_connections("deerflow-user-1") + assert len(connections) == 1 + 
assert connections[0]["provider"] == "telegram" + assert connections[0]["external_account_id"] == "42" + assert connections[0]["external_account_name"] == "Alice Example" + assert connections[0]["workspace_id"] == "100" + assert connections[0]["metadata"]["chat_type"] == "private" + update.message.reply_text.assert_awaited_once() + assert "connected" in update.message.reply_text.await_args.args[0].lower() + + +@pytest.mark.anyio +async def test_bound_telegram_message_publishes_connection_identity(repo): + connection = await repo.upsert_connection( + owner_user_id="deerflow-user-1", + provider="telegram", + external_account_id="42", + external_account_name="Alice Example", + workspace_id="100", + metadata={"chat_type": "private"}, + ) + bus = MessageBus() + channel = TelegramChannel( + bus=bus, + config={"bot_token": "test-token", "connection_repo": repo}, + ) + channel._main_loop = __import__("asyncio").get_event_loop() + channel._send_running_reply = AsyncMock() + + await channel._on_text(_telegram_update(text="hello"), None) + inbound = await bus.get_inbound() + + assert inbound.connection_id == connection["id"] + assert inbound.owner_user_id == "deerflow-user-1" + assert inbound.workspace_id == "100" + assert inbound.user_id == "42" + assert inbound.chat_id == "100" + assert inbound.text == "hello" diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 1cef3752b..c6fff8868 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -137,6 +137,19 @@ class TestThreadMetaRepository: async def test_update_metadata_nonexistent_is_noop(self, repo): await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise + @pytest.mark.anyio + async def test_update_owner_with_bypass_moves_row(self, repo): + await repo.create("t1", user_id="default", metadata={"source": "channel"}) + await repo.update_owner("t1", "owner-1", user_id=None) + + owner_row = await repo.get("t1", 
user_id="owner-1") + default_row = await repo.get("t1", user_id="default") + + assert owner_row is not None + assert owner_row["user_id"] == "owner-1" + assert owner_row["metadata"] == {"source": "channel"} + assert default_row is None + # --- search with metadata filter (SQL push-down) --- @pytest.mark.anyio diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index f6f6adcef..74e4c7a50 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -1,4 +1,5 @@ import re +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None: assert body["created_at"] == body["updated_at"] +def test_internal_owner_header_assigns_thread_to_owner() -> None: + import asyncio + + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE + + store = InMemoryStore() + checkpointer = InMemorySaver() + thread_store = MemoryThreadMetaStore(store) + request = SimpleNamespace( + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"}, + state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)), + app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)), + ) + + async def _scenario(): + response = await threads.create_thread( + threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}), + request, + ) + owner_row = await thread_store.get("channel-thread", user_id="owner-1") + internal_row = await thread_store.get("channel-thread", user_id="default") + return response, owner_row, internal_row + + response, owner_row, internal_row = asyncio.run(_scenario()) + + assert response.thread_id == "channel-thread" + assert owner_row is not None + assert owner_row["user_id"] == "owner-1" + assert internal_row is None + + def test_get_thread_returns_iso_for_legacy_unix_record() -> None: """A thread record written by 
older versions stores ``time.time()`` floats. ``get_thread`` must transparently surface them as ISO so the diff --git a/backend/uv.lock b/backend/uv.lock index f4008b9a1..91627c878 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -820,6 +820,7 @@ dependencies = [ { name = "agent-sandbox" }, { name = "aiosqlite" }, { name = "alembic" }, + { name = "cryptography" }, { name = "ddgs" }, { name = "dotenv" }, { name = "duckdb" }, @@ -871,6 +872,7 @@ requires-dist = [ { name = "aiosqlite", specifier = ">=0.19" }, { name = "alembic", specifier = ">=1.13" }, { name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" }, + { name = "cryptography", specifier = ">=43.0.0" }, { name = "ddgs", specifier = ">=9.10.0" }, { name = "dotenv", specifier = ">=0.9.9" }, { name = "duckdb", specifier = ">=1.4.4" }, diff --git a/config.example.yaml b/config.example.yaml index 75c83228f..73af462f6 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1140,6 +1140,45 @@ run_events: max_trace_content: 10240 track_token_usage: true +# ============================================================================ +# User-Owned IM Channel Connections +# ============================================================================ +# Lets logged-in users connect their own IM accounts from the DeerFlow frontend +# while reusing the existing `channels` runtime configuration below. +# +# Security notes: +# - No public IP, OAuth callback URL, or provider webhook is required. +# - Provider bot/app credentials stay under `channels.*`. +# - `channel_connections` stores per-user bindings and one-time connect codes. +# - Telegram uses a deep link when `bot_username` is configured. +# - Slack, Discord, Feishu, DingTalk, WeChat, and WeCom use `/connect ` +# through the already-running bot/app. 
+# +# channel_connections: +# enabled: false +# +# telegram: +# enabled: false +# bot_username: $TELEGRAM_BOT_USERNAME +# +# slack: +# enabled: false +# +# discord: +# enabled: false +# +# feishu: +# enabled: false +# +# dingtalk: +# enabled: false +# +# wechat: +# enabled: false +# +# wecom: +# enabled: false + # ============================================================================ # IM Channels Configuration # ============================================================================ diff --git a/frontend/AGENTS.md b/frontend/AGENTS.md index 036927a2b..0d4e770b0 100644 --- a/frontend/AGENTS.md +++ b/frontend/AGENTS.md @@ -52,6 +52,7 @@ src/ ├── core/ # Core business logic │ ├── api/ # API client & data fetching │ ├── artifacts/ # Artifact management +│ ├── channels/ # IM channel connections (providers, connect flow) │ ├── config/ # App configuration │ ├── i18n/ # Internationalization │ ├── mcp/ # MCP integration diff --git a/frontend/CLAUDE.md b/frontend/CLAUDE.md index d431670d2..33da1f276 100644 --- a/frontend/CLAUDE.md +++ b/frontend/CLAUDE.md @@ -48,6 +48,7 @@ The frontend is a stateful chat application. 
Users create **threads** (conversat - `threads/` — Thread creation, streaming, state management (hooks + types) - `api/` — LangGraph client singleton - `artifacts/` — Artifact loading and caching + - `channels/` — IM channel connections (provider catalog, connect/runtime-config API + hooks) - `i18n/` — Internationalization (en-US, zh-CN) - `settings/` — User preferences in localStorage - `memory/` — Persistent user memory system diff --git a/frontend/src/app/workspace/chats/page.tsx b/frontend/src/app/workspace/chats/page.tsx index fdd4dd454..faaeb8ff2 100644 --- a/frontend/src/app/workspace/chats/page.tsx +++ b/frontend/src/app/workspace/chats/page.tsx @@ -6,6 +6,10 @@ import { useEffect, useMemo, useRef, useState } from "react"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { ScrollArea } from "@/components/ui/scroll-area"; +import { + ThreadChannelBadge, + ThreadChannelIcon, +} from "@/components/workspace/thread-channel-source"; import { WorkspaceBody, WorkspaceContainer, @@ -13,7 +17,11 @@ import { } from "@/components/workspace/workspace-container"; import { useI18n } from "@/core/i18n/hooks"; import { useInfiniteThreads } from "@/core/threads/hooks"; -import { pathOfThread, titleOfThread } from "@/core/threads/utils"; +import { + channelSourceOfThread, + pathOfThread, + titleOfThread, +} from "@/core/threads/utils"; import { formatTimeAgo } from "@/core/utils/datetime"; export default function ChatsPage() { @@ -82,20 +90,30 @@ export default function ChatsPage() {
- {filteredThreads?.map((thread) => ( - -
-
-
{titleOfThread(thread)}
-
- {thread.updated_at && ( -
- {formatTimeAgo(thread.updated_at)} + {filteredThreads.map((thread) => { + const channelSource = channelSourceOfThread(thread); + return ( + +
+
+ +
+ {titleOfThread(thread)} +
+
- )} -
- - ))} + {thread.updated_at && ( +
+ {formatTimeAgo(thread.updated_at)} +
+ )} +
+ + ); + })} {hasNextPage && !isSearching && (
& { + provider: string; +}; + +export function ChannelProviderIcon({ + provider, + className, + ...props +}: ChannelProviderIconProps) { + const normalizedProvider = provider.toLowerCase(); + + if (normalizedProvider === "telegram") { + return ( + + ); + } + + if (normalizedProvider === "slack") { + return ( + + ); + } + + if (normalizedProvider === "discord") { + return ( + + ); + } + + if (normalizedProvider === "feishu") { + return ( + + ); + } + + if (normalizedProvider === "dingtalk") { + return ( + + ); + } + + if (normalizedProvider === "wechat") { + return ( + + ); + } + + if (normalizedProvider === "wecom") { + return ( + + ); + } + + return ( +