feat(im): Add user-owned IM channel connections (#3487)

* Add user-owned IM channel connections

* Fix dev startup and channel connect popup

* Use async channel connect flow

* Harden dev service daemon startup

* Support local IM channel connections

* Align IM connections with local channels

* Fix safe user id digest algorithm

* Address Copilot IM channel feedback

* Address IM channel review comments

* Support all integrated IM channel connections

* Format additional channel connection tests

* Keep unavailable channel connect buttons clickable

* Fix IM channel provider icons

* Add runtime setup for enabled IM channels

* Guard global shortcut key handling

* Keep configured IM channels editable

* Avoid password autofill for channel secrets

* Make channel threads visible to connection owners

* Persist IM runtime config locally

* Allow disconnecting runtime IM channels

* Route no-auth channel sessions to local user

* Use default user for auth-disabled local mode

* Show IM channel source on threads

* Prefill IM channel runtime config

* Reflect IM channel runtime health

* Ignore Feishu message read events

* Ignore Feishu non-content message events

* Let setup wizard enable IM channels

* Fix frontend formatting after merge

* Stabilize backend tests without local config

* Isolate channel runtime config tests

* Address channel connection review comments

* Use sha256 user buckets with legacy migration

* Ensure runtime IM channels are ready after restart

* Persist disconnected IM channel state

* Address channel connection review comments

* Address channel connection review findings

Frontend connect flow:
- Open the runtime-config dialog only when a provider still needs
  credentials; configured providers go straight to the connect flow, so
  the binding-code/deep-link path is reachable from the UI again.
- After saving credentials, continue into the connect flow when a user
  binding is still required (multi-user mode) instead of stopping at a
  "Connected" toast.
- Extract shared provider-state helpers to core/channels/provider-state
  and add unit + e2e coverage for the direct-connect and
  configure-then-connect paths.

Provider status semantics:
- Report connection_status from the user's newest connection row;
  with no binding it is not_connected, except in auth-disabled local
  mode where a configured running channel is effectively connected.

Concurrency and event-loop correctness:
- Offload ChannelRuntimeConfigStore construction and writes, channel
  service construction, and Slack connection replies to threads; add a
  tests/blocking_io/ anchor for the runtime-config handlers.
- Consume binding codes with a conditional UPDATE so a code can only be
  used once under concurrent workers; retry upsert_connection as an
  update when a concurrent insert wins the unique constraint.
- Serialize ensure_channel_ready per channel so concurrent provider
  polls cannot double-start a channel worker.

Config and migration hardening:
- Stop mutating the get_app_config()-cached Telegram provider config;
  the runtime store now owns the UI-entered bot username.
- Register channel_connections in STARTUP_ONLY_FIELDS with the
  standardized startup-only Field description.
- Match the legacy unsafe-id bucket by recomputing its exact SHA-1 name
  so another user's same-prefix bucket can never be migrated.
- Remove the unused Telegram process_webhook_update path and document
  src/core/channels in the frontend docs.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Address PR review comments on authz scoping and channel runtime

Security (review feedback from ShenAC-SAC):
- Scope internal-token callers to the connection owner carried in
  X-DeerFlow-Owner-User-Id instead of bypassing owner checks outright,
  in both require_permission(owner_check=True) and the stateless run
  endpoints. Internal callers keep access to their own and
  shared/legacy threads, and may claim a default-owned channel thread
  for its real owner, but a leaked internal token no longer grants
  cross-user thread access.
- Require admin privileges for POST/DELETE /api/channels/{provider}/
  runtime-config: runtime credentials and channel workers are
  instance-wide shared state (same model as the MCP config API).
  Read-only provider listing stays available to all users.

Performance (review feedback from willem-bd):
- Skip the redundant thread channel-metadata PATCH after the first
  successful backfill per thread.
- Reuse the per-connection Slack WebClient until its token changes
  instead of constructing one per outbound message.
- Reconcile channel readiness for all providers concurrently in
  GET /api/channels/providers.

Also resolve the code-quality unused-import flag in the blocking-io
anchor by pre-importing the channel service via importlib.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Fix prettier formatting in provider-state test

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Reconcile UI runtime channel config with config reload on restart

Main now reloads a channel's config.yaml entry on restart_channel()
(#3514, issue #3497). Adapt the user-owned connection flow to coexist:

- configure_channel() restarts with reload_config=False — the caller
  just supplied the authoritative config (browser-entered credentials
  that are never written to config.yaml), so a file reload must not
  clobber it with the stale on-disk entry.
- _load_channel_config() re-applies the UI runtime-store overlay used
  at startup, so an operator-triggered restart keeps browser-entered
  credentials for channels without a config.yaml entry and does not
  resurrect a channel disconnected from the UI.
- Offload the reload's disk IO (config.yaml + runtime store) with
  asyncio.to_thread, matching the blocking-IO policy on this branch.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-06-12 15:24:58 +08:00
committed by GitHub
parent b8f5ed360f
commit aa015462a7
96 changed files with 8585 additions and 277 deletions
+2
View File
@@ -343,6 +343,8 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them. DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
| Channel | Transport | Difficulty | | Channel | Transport | Difficulty |
|---------|-----------|------------| |---------|-----------|------------|
| Telegram | Bot API (long-polling) | Easy | | Telegram | Bot API (long-polling) | Easy |
+24 -11
View File
@@ -234,7 +234,7 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`. **Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state``lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work and passes it to `langgraph_runtime(app, startup_config)`.
Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`. Infrastructure fields are **restart-required**. The authoritative list lives in `packages/harness/deerflow/config/reload_boundary.py::STARTUP_ONLY_FIELDS` and is mirrored by the standardised `"startup-only:"` prefix on the corresponding `Field(description=...)` in `AppConfig`, so IDE hover on those fields surfaces the reason inline (no need to context-switch into this table). Currently registered: `database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox`, `log_level`, `channels`, `channel_connections`. Adding a new restart-required field requires updating the registry; drift is pinned by `tests/test_reload_boundary.py`.
Configuration priority: Configuration priority:
1. Explicit `config_path` argument 1. Explicit `config_path` argument
@@ -377,8 +377,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
### IM Channels System (`app/channels/`) ### IM Channels System (`app/channels/`)
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API. Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies. **Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
@@ -388,18 +387,21 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates - `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle) - `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
- `service.py` - Manages lifecycle of all configured channels from `config.yaml` - `service.py` - Manages lifecycle of all configured channels from `config.yaml`
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured) - `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
**Message Flow**: **Message Flow**:
1. External platform -> Channel impl -> `MessageBus.publish_inbound()` 1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
2. `ChannelManager._dispatch_loop()` consumes from queue 2. `ChannelManager._dispatch_loop()` consumes from queue
3. For chat: look up/create thread through Gateway's LangGraph-compatible API 3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`) 4. For chat: look up/create thread through Gateway's LangGraph-compatible API
5. Slack/Telegram chat: `runs.wait()`extract final response → publish outbound 5. Feishu chat: `runs.stream()`accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement) 6. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails 7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API 8. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
9. Outbound → channel callbacks → platform reply 9. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
10. Outbound → channel callbacks → platform reply
**Configuration** (`config.yaml` -> `channels`): **Configuration** (`config.yaml` -> `channels`):
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`) - `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
@@ -407,6 +409,17 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`. - In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming) - Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers.
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user.
- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable.
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
### Memory System (`packages/harness/deerflow/agents/memory/`) ### Memory System (`packages/harness/deerflow/agents/memory/`)
+11
View File
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
) )
def extract_connect_code(text: str) -> str | None:
"""Extract the one-time channel binding code from a connect command."""
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
def is_known_channel_command(text: str) -> bool: def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command.""" """Return whether text starts with a registered channel control command."""
if not text.startswith("/"): if not text.startswith("/"):
@@ -0,0 +1,44 @@
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
from __future__ import annotations
from typing import Any
from app.channels.message_bus import InboundMessage
async def attach_connection_identity(
inbound: InboundMessage,
*,
repo: Any,
provider: str,
workspace_id: str | None,
fallback_without_workspace: bool = False,
) -> InboundMessage:
"""Attach connection metadata to an inbound message when a persisted binding exists."""
if repo is None:
return inbound
workspace_candidates: list[str | None] = []
if workspace_id:
workspace_candidates.append(workspace_id)
if fallback_without_workspace:
workspace_candidates.append(None)
if not workspace_candidates:
return inbound
for candidate in workspace_candidates:
connection = await repo.find_connection_by_external_identity(
provider=provider,
external_account_id=inbound.user_id,
workspace_id=candidate,
)
if connection is None:
continue
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return inbound
+105 -1
View File
@@ -14,7 +14,8 @@ from typing import Any
import httpx import httpx
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
self._incoming_messages: dict[str, Any] = {} self._incoming_messages: dict[str, Any] = {}
self._incoming_messages_lock = threading.Lock() self._incoming_messages_lock = threading.Lock()
self._card_repliers: dict[str, Any] = {} self._card_repliers: dict[str, Any] = {}
self._connection_repo = config.get("connection_repo")
@property @property
def supports_streaming(self) -> bool: def supports_streaming(self) -> bool:
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
text[:100], text[:100],
) )
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
conversation_type=conversation_type,
sender_staff_id=sender_staff_id,
sender_nick=sender_nick,
conversation_id=conversation_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
return
if _is_dingtalk_command(text): if _is_dingtalk_command(text):
msg_type = InboundMessageType.COMMAND msg_type = InboundMessageType.COMMAND
else: else:
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
return "" return ""
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None: async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
inbound = await self._attach_connection_identity(inbound)
# Running reply must finish before publish_inbound so AI card tracks are # Running reply must finish before publish_inbound so AI card tracks are
# registered before the manager emits streaming outbounds. # registered before the manager emits streaming outbounds.
await self._send_running_reply(chat_id, inbound) await self._send_running_reply(chat_id, inbound)
await self.bus.publish_inbound(inbound) await self.bus.publish_inbound(inbound)
@staticmethod
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
return conversation_id
return None
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
conversation_id = str(inbound.metadata.get("conversation_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="dingtalk",
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(
self,
*,
conversation_type: str,
sender_staff_id: str,
sender_nick: str,
conversation_id: str,
code: str,
) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
if state is None:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection code is invalid or expired.",
)
return True
if not sender_staff_id:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection could not be completed from this message.",
)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="dingtalk",
external_account_id=sender_staff_id,
external_account_name=sender_nick or None,
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
metadata={
"conversation_type": conversation_type,
"conversation_id": conversation_id,
},
status="connected",
)
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connected to DeerFlow.",
)
return True
async def _send_connection_reply(
self,
conversation_type: str,
sender_staff_id: str,
conversation_id: str,
text: str,
) -> None:
robot_code = self._client_id
if conversation_type == _CONVERSATION_TYPE_GROUP:
if conversation_id:
await self._send_text_message_to_group(robot_code, conversation_id, text)
return
if sender_staff_id:
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None: async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P) conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
sender_staff_id = inbound.metadata.get("sender_staff_id", "") sender_staff_id = inbound.metadata.get("sender_staff_id", "")
+64 -2
View File
@@ -10,8 +10,9 @@ from pathlib import Path
from typing import Any from typing import Any
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
self._discord_loop: asyncio.AbstractEventLoop | None = None self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None self._discord_module = None
self._connection_repo = config.get("connection_repo")
async def start(self) -> None: async def start(self) -> None:
if self._running: if self._running:
@@ -287,6 +289,10 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip() text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread) # Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
# --- Determine thread/channel routing and typing target --- # --- Determine thread/channel routing and typing target ---
thread_id = None thread_id = None
chat_id = None chat_id = None
@@ -315,6 +321,7 @@ class DiscordChannel(Channel):
}, },
) )
inbound.topic_id = thread_id inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
self._publish(inbound) self._publish(inbound)
# Start typing indicator in the thread # Start typing indicator in the thread
if typing_target: if typing_target:
@@ -422,6 +429,7 @@ class DiscordChannel(Channel):
}, },
) )
inbound.topic_id = thread_id inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
# Start typing indicator in the correct target (thread or channel) # Start typing indicator in the correct target (thread or channel)
if typing_target: if typing_target:
@@ -436,6 +444,60 @@ class DiscordChannel(Channel):
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop) future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None) future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="discord",
workspace_id=guild_id,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
if state is None:
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
return True
guild = getattr(message, "guild", None)
channel = getattr(message, "channel", None)
author = getattr(message, "author", None)
user_id = str(getattr(author, "id", "") or "")
if not user_id:
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
return True
guild_id = str(getattr(guild, "id", "") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="discord",
external_account_id=user_id,
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
workspace_id=guild_id,
workspace_name=getattr(guild, "name", None) if guild is not None else None,
metadata={
"guild_id": guild_id,
"channel_id": str(getattr(channel, "id", "") or ""),
},
status="connected",
)
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
return True
@staticmethod
async def _send_connection_reply(message, text: str) -> None:
channel = getattr(message, "channel", None)
send = getattr(channel, "send", None)
if send is None:
return
try:
await send(text)
except Exception:
logger.exception("[Discord] failed to send connection reply")
def _run_client(self) -> None: def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop() self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop) asyncio.set_event_loop(self._discord_loop)
+78 -2
View File
@@ -11,7 +11,8 @@ import time
from typing import Any, Literal from typing import Any, Literal
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import ( from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY, PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY, RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
self._CreateImageRequestBody = None self._CreateImageRequestBody = None
self._GetMessageResourceRequest = None self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock() self._thread_lock = threading.Lock()
self._connection_repo = config.get("connection_repo")
@staticmethod @staticmethod
def _non_empty_str(value: Any) -> str | None: def _non_empty_str(value: Any) -> str | None:
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
def supports_streaming(self) -> bool: def supports_streaming(self) -> bool:
return True return True
@property
def is_running(self) -> bool:
if not self._running:
return False
return self._thread is not None and self._thread.is_alive()
def _build_event_handler(self, lark):
return (
lark.EventDispatcherHandler.builder("", "")
.register_p2_im_message_receive_v1(self._on_message)
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
.build()
)
async def start(self) -> None: async def start(self) -> None:
if self._running: if self._running:
return return
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
# thread's uvloop. # thread's uvloop.
_ws_client_mod.loop = loop _ws_client_mod.loop = loop
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build() event_handler = self._build_event_handler(lark)
ws_client = lark.ws.Client( ws_client = lark.ws.Client(
app_id=app_id, app_id=app_id,
app_secret=app_secret, app_secret=app_secret,
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
except Exception: except Exception:
if self._running: if self._running:
logger.exception("Feishu WebSocket error") logger.exception("Feishu WebSocket error")
self._running = False
def _on_ignored_message_event(self, event) -> None:
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
async def stop(self) -> None: async def stop(self) -> None:
self._running = False self._running = False
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
async def _prepare_inbound(self, msg_id: str, inbound) -> None: async def _prepare_inbound(self, msg_id: str, inbound) -> None:
"""Kick off Feishu side effects without delaying inbound dispatch.""" """Kick off Feishu side effects without delaying inbound dispatch."""
inbound = await self._attach_connection_identity(inbound)
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK")) reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id) self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
self._ensure_running_card_started(msg_id) self._ensure_running_card_started(msg_id)
await self.bus.publish_inbound(inbound) await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="feishu",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
if state is None:
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
return True
if not user_id or not chat_id:
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="feishu",
external_account_id=user_id,
workspace_id=chat_id,
metadata={
"chat_id": chat_id,
"message_id": message_id,
},
status="connected",
)
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
return True
def _on_message(self, event) -> None: def _on_message(self, event) -> None:
"""Called by lark-oapi when a message is received (runs in lark thread).""" """Called by lark-oapi when a message is received (runs in lark thread)."""
try: try:
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
logger.info("[Feishu] empty text, ignoring message") logger.info("[Feishu] empty text, ignoring message")
return return
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
message_id=msg_id,
chat_id=chat_id,
user_id=sender_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
return
# Only treat known slash commands as commands; absolute paths and # Only treat known slash commands as commands; absolute paths and
# other slash-prefixed text should be handled as normal chat. # other slash-prefixed text should be handled as normal chat.
if _is_feishu_command(text): if _is_feishu_command(text):
+152 -30
View File
@@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
return metadata return metadata
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
channel_source: dict[str, Any] = {
"type": "im_channel",
"provider": msg.channel_name,
"chat_id": msg.chat_id,
}
if msg.topic_id:
channel_source["topic_id"] = msg.topic_id
if msg.thread_ts:
channel_source["thread_ts"] = msg.thread_ts
if msg.connection_id:
channel_source["connection_id"] = msg.connection_id
return {"channel_source": channel_source}
def _extract_text_content(content: Any) -> str: def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field.""" """Extract text from a streaming payload content field."""
if isinstance(content, str): if isinstance(content, str):
@@ -440,6 +456,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
return message return message
def _auth_disabled_owner_user_id() -> str | None:
try:
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
except Exception:
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
return None
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
return _auth_disabled_owner_user_id() or msg.owner_user_id
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
msg.owner_user_id = owner_user_id
return msg
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
owner_user_id = _effective_owner_user_id(msg)
if not owner_user_id:
return None
return create_internal_auth_headers(owner_user_id=owner_user_id)
def _safe_user_id_for_run(raw_user_id: str) -> str:
from deerflow.config.paths import get_paths
try:
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
except Exception:
logger.exception("Failed to prepare channel run user directory")
return make_safe_user_id(raw_user_id)
def _resolve_slash_skill_command( def _resolve_slash_skill_command(
text: str, text: str,
available_skills: set[str] | None = None, available_skills: set[str] | None = None,
@@ -670,6 +723,7 @@ class ChannelManager:
assistant_id: str = DEFAULT_ASSISTANT_ID, assistant_id: str = DEFAULT_ASSISTANT_ID,
default_session: dict[str, Any] | None = None, default_session: dict[str, Any] | None = None,
channel_sessions: dict[str, Any] | None = None, channel_sessions: dict[str, Any] | None = None,
connection_repo: Any | None = None,
) -> None: ) -> None:
self.bus = bus self.bus = bus
self.store = store self.store = store
@@ -679,7 +733,9 @@ class ChannelManager:
self._assistant_id = assistant_id self._assistant_id = assistant_id
self._default_session = _as_dict(default_session) self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {}) self._channel_sessions = dict(channel_sessions or {})
self._connection_repo = connection_repo
self._client = None # lazy init — langgraph_sdk async client self._client = None # lazy init — langgraph_sdk async client
self._channel_metadata_synced: set[str] = set()
self._skill_storage: SkillStorage | None = None self._skill_storage: SkillStorage | None = None
self._csrf_token = generate_csrf_token() self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None self._semaphore: asyncio.Semaphore | None = None
@@ -728,12 +784,17 @@ class ChannelManager:
configurable["checkpoint_ns"] = "" configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id configurable["thread_id"] = thread_id
# ``user_id`` drives user-scoped filesystem buckets that only accept # ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value # For browser-connected IM channels, prefer the DeerFlow account that
# under ``channel_user_id`` for platform-facing lookups. # owns the connection. Preserve the raw platform user under
# ``channel_user_id`` for platform-facing lookups and audits.
run_context_identity: dict[str, Any] = {"thread_id": thread_id} run_context_identity: dict[str, Any] = {"thread_id": thread_id}
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
elif msg.user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
if msg.user_id: if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id run_context_identity["channel_user_id"] = msg.user_id
run_context = _merge_dicts( run_context = _merge_dicts(
@@ -845,6 +906,7 @@ class ChannelManager:
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc) logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
async def _handle_message(self, msg: InboundMessage) -> None: async def _handle_message(self, msg: InboundMessage) -> None:
msg = _apply_effective_owner(msg)
async with self._semaphore: async with self._semaphore:
try: try:
if msg.msg_type == InboundMessageType.COMMAND: if msg.msg_type == InboundMessageType.COMMAND:
@@ -877,10 +939,27 @@ class ChannelManager:
# -- chat handling ----------------------------------------------------- # -- chat handling -----------------------------------------------------
async def _create_thread(self, client, msg: InboundMessage) -> str: async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
"""Create a new thread through Gateway and store the mapping.""" if msg.connection_id and self._connection_repo is not None:
thread = await client.threads.create() return await self._connection_repo.get_thread_id(
thread_id = thread["thread_id"] msg.connection_id,
msg.chat_id,
msg.topic_id,
)
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
await self._connection_repo.set_thread_id(
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
provider=msg.channel_name,
external_conversation_id=msg.chat_id,
external_topic_id=msg.topic_id,
thread_id=thread_id,
)
return
self.store.set_thread_id( self.store.set_thread_id(
msg.channel_name, msg.channel_name,
msg.chat_id, msg.chat_id,
@@ -888,18 +967,49 @@ class ChannelManager:
topic_id=msg.topic_id, topic_id=msg.topic_id,
user_id=msg.user_id, user_id=msg.user_id,
) )
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
metadata = _thread_channel_metadata(msg)
owner_headers = _owner_headers(msg)
if owner_headers:
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
else:
thread = await client.threads.create(metadata=metadata)
thread_id = thread["thread_id"]
await self._store_thread_id(msg, thread_id)
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id) logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
return thread_id return thread_id
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
"""Best-effort source metadata backfill for existing IM-created threads."""
# The metadata (provider/chat/topic) is constant for a thread, so one
# successful backfill per manager lifetime is enough — skip the
# redundant PATCH on every subsequent inbound message.
if thread_id in self._channel_metadata_synced:
return
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
if owner_headers := _owner_headers(msg):
update_kwargs["headers"] = owner_headers
try:
await client.threads.update(thread_id, **update_kwargs)
except Exception:
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
return
if len(self._channel_metadata_synced) > 4096:
self._channel_metadata_synced.clear()
self._channel_metadata_synced.add(thread_id)
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None: async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
client = self._get_client() client = self._get_client()
# Look up existing DeerFlow thread. # Look up existing DeerFlow thread.
# topic_id may be None (e.g. Telegram private chats) — the store # topic_id may be None (e.g. Telegram private chats) — the store
# handles this by using the "channel:chat_id" key without a topic suffix. # handles this by using the "channel:chat_id" key without a topic suffix.
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) thread_id = await self._lookup_thread_id(msg)
if thread_id: if thread_id:
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id) logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
await self._update_thread_channel_metadata(client, msg, thread_id)
# No existing thread found — create a new one # No existing thread found — create a new one
if thread_id is None: if thread_id is None:
@@ -940,14 +1050,19 @@ class ChannelManager:
return return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
run_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
run_kwargs["headers"] = owner_headers
try: try:
result = await client.runs.wait( result = await client.runs.wait(
thread_id, thread_id,
assistant_id, assistant_id,
input={"messages": [human_message]}, **run_kwargs,
config=run_config,
context=run_context,
multitask_strategy="reject",
) )
except Exception as exc: except Exception as exc:
if _is_thread_busy_error(exc): if _is_thread_busy_error(exc):
@@ -984,6 +1099,8 @@ class ChannelManager:
artifacts=artifacts, artifacts=artifacts,
attachments=attachments, attachments=attachments,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
) )
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id) logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
@@ -1008,16 +1125,21 @@ class ChannelManager:
last_published_text = "" last_published_text = ""
last_publish_at = 0.0 last_publish_at = 0.0
stream_error: BaseException | None = None stream_error: BaseException | None = None
stream_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"stream_mode": ["messages-tuple", "values"],
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
stream_kwargs["headers"] = owner_headers
try: try:
async for chunk in client.runs.stream( async for chunk in client.runs.stream(
thread_id, thread_id,
assistant_id, assistant_id,
input={"messages": [human_message]}, **stream_kwargs,
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
multitask_strategy="reject",
): ):
event = getattr(chunk, "event", "") event = getattr(chunk, "event", "")
data = getattr(chunk, "data", None) data = getattr(chunk, "data", None)
@@ -1047,6 +1169,8 @@ class ChannelManager:
text=latest_text, text=latest_text,
is_final=False, is_final=False,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata), metadata=_response_metadata(msg.metadata),
) )
) )
@@ -1093,6 +1217,8 @@ class ChannelManager:
attachments=attachments, attachments=attachments,
is_final=True, is_final=True,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification), metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
) )
) )
@@ -1124,18 +1250,10 @@ class ChannelManager:
if reply is None and command == "new": if reply is None and command == "new":
# Create a new thread through Gateway # Create a new thread through Gateway
client = self._get_client() client = self._get_client()
thread = await client.threads.create() await self._create_thread(client, msg)
new_thread_id = thread["thread_id"]
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
new_thread_id,
topic_id=msg.topic_id,
user_id=msg.user_id,
)
reply = "New conversation started." reply = "New conversation started."
elif reply is None and command == "status": elif reply is None and command == "status":
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) thread_id = await self._lookup_thread_id(msg)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation." reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif reply is None and command == "models": elif reply is None and command == "models":
reply = await self._fetch_gateway("/api/models", "models") reply = await self._fetch_gateway("/api/models", "models")
@@ -1174,9 +1292,11 @@ class ChannelManager:
outbound = OutboundMessage( outbound = OutboundMessage(
channel_name=msg.channel_name, channel_name=msg.channel_name,
chat_id=msg.chat_id, chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "", thread_id=await self._lookup_thread_id(msg) or "",
text=reply, text=reply,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata), metadata=_slim_metadata(msg.metadata),
) )
await self.bus.publish_outbound(outbound) await self.bus.publish_outbound(outbound)
@@ -1212,9 +1332,11 @@ class ChannelManager:
outbound = OutboundMessage( outbound = OutboundMessage(
channel_name=msg.channel_name, channel_name=msg.channel_name,
chat_id=msg.chat_id, chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "", thread_id=await self._lookup_thread_id(msg) or "",
text=error_text, text=error_text,
thread_ts=msg.thread_ts, thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata), metadata=_slim_metadata(msg.metadata),
) )
await self.bus.publish_outbound(outbound) await self.bus.publish_outbound(outbound)
+14
View File
@@ -44,6 +44,12 @@ class InboundMessage:
Messages sharing the same ``topic_id`` within a ``chat_id`` will Messages sharing the same ``topic_id`` within a ``chat_id`` will
reuse the same DeerFlow thread. When ``None``, each message reuse the same DeerFlow thread. When ``None``, each message
creates a new thread (one-shot Q&A). creates a new thread (one-shot Q&A).
connection_id: Optional DeerFlow channel connection id. When present,
conversation mapping is scoped by the connection instead of the
legacy global ``channel_name:chat_id[:topic_id]`` key.
owner_user_id: DeerFlow user id that owns the channel connection.
Platform user ids stay in ``user_id``.
workspace_id: Optional external workspace/guild/team id.
files: Optional list of file attachments (platform-specific dicts). files: Optional list of file attachments (platform-specific dicts).
metadata: Arbitrary extra data from the channel. metadata: Arbitrary extra data from the channel.
created_at: Unix timestamp when the message was created. created_at: Unix timestamp when the message was created.
@@ -56,6 +62,9 @@ class InboundMessage:
msg_type: InboundMessageType = InboundMessageType.CHAT msg_type: InboundMessageType = InboundMessageType.CHAT
thread_ts: str | None = None thread_ts: str | None = None
topic_id: str | None = None topic_id: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
workspace_id: str | None = None
files: list[dict[str, Any]] = field(default_factory=list) files: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time) created_at: float = field(default_factory=time.time)
@@ -95,6 +104,9 @@ class OutboundMessage:
is_final: Whether this is the final message in the response stream. is_final: Whether this is the final message in the response stream.
thread_ts: Optional platform thread identifier for threaded replies. thread_ts: Optional platform thread identifier for threaded replies.
metadata: Arbitrary extra data. metadata: Arbitrary extra data.
connection_id: Optional DeerFlow channel connection id used for
connection-specific outbound credentials.
owner_user_id: DeerFlow user id that owns the channel connection.
created_at: Unix timestamp. created_at: Unix timestamp.
""" """
@@ -106,6 +118,8 @@ class OutboundMessage:
attachments: list[ResolvedAttachment] = field(default_factory=list) attachments: list[ResolvedAttachment] = field(default_factory=list)
is_final: bool = True is_final: bool = True
thread_ts: str | None = None thread_ts: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time) created_at: float = field(default_factory=time.time)
@@ -0,0 +1,154 @@
"""Local persistence for runtime IM channel configuration."""
from __future__ import annotations
import json
import logging
import tempfile
import threading
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
class ChannelRuntimeConfigStore:
"""JSON-backed store for channel credentials entered from the UI.
This intentionally mirrors ``ChannelStore``: local/private deployments get
durable runtime configuration without needing a public callback URL or a
config.yaml edit.
"""
def __init__(self, path: str | Path | None = None) -> None:
if path is None:
from deerflow.config.paths import get_paths
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
self._path = Path(path)
self._path.parent.mkdir(parents=True, exist_ok=True)
self._data: dict[str, dict[str, Any]] = self._load()
self._lock = threading.Lock()
def _load(self) -> dict[str, dict[str, Any]]:
if self._path.exists():
try:
raw = json.loads(self._path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
return {}
if isinstance(raw, dict):
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
return {}
def _save(self) -> None:
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=self._path.parent,
suffix=".tmp",
delete=False,
)
try:
json.dump(self._data, fd, indent=2, ensure_ascii=False)
fd.close()
Path(fd.name).replace(self._path)
try:
self._path.chmod(0o600)
except OSError:
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def load_all(self) -> dict[str, dict[str, Any]]:
with self._lock:
return {name: dict(config) for name, config in self._data.items()}
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
with self._lock:
config = self._data.get(provider)
return dict(config) if isinstance(config, dict) else None
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
with self._lock:
self._data[provider] = dict(config)
self._save()
def set_provider_disconnected(self, provider: str) -> None:
with self._lock:
self._data[provider] = {
"enabled": False,
RUNTIME_CHANNEL_DISABLED_FLAG: True,
}
self._save()
def remove_provider_config(self, provider: str) -> bool:
with self._lock:
if provider not in self._data:
return False
del self._data[provider]
self._save()
return True
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
provider_config = getattr(channel_connections_config, provider, None)
return bool(getattr(provider_config, "enabled", False))
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
def merge_runtime_channel_configs(
channels_config: dict[str, Any],
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> None:
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return
runtime_store = store or ChannelRuntimeConfigStore()
for provider, runtime_config in runtime_store.load_all().items():
if not _provider_enabled(channel_connections_config, provider):
continue
if _runtime_channel_disconnected(runtime_config):
channels_config.pop(provider, None)
continue
existing = channels_config.get(provider)
merged = dict(runtime_config)
if isinstance(existing, dict):
merged.update(existing)
channels_config[provider] = merged
def apply_runtime_connection_config(
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> Any:
"""Apply persisted connection metadata that lives outside ``channels``.
Telegram uses a bot username for deep links; UI-entered values are stored
with the runtime channel config so local restarts keep the provider
configured.
"""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return channel_connections_config
runtime_store = store or ChannelRuntimeConfigStore()
telegram_runtime_config = runtime_store.get_provider_config("telegram")
bot_username = ""
if isinstance(telegram_runtime_config, dict):
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
return channel_connections_config
config = channel_connections_config.model_copy(deep=True)
config.telegram.bot_username = bot_username
return config
+145 -26
View File
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus from app.channels.message_bus import MessageBus
from app.channels.runtime_config_store import merge_runtime_channel_configs
from app.channels.store import ChannelStore from app.channels.store import ChannelStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL" _CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str: def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
value = config.pop(config_key, None) value = config.pop(config_key, None)
if isinstance(value, str) and value.strip(): if isinstance(value, str) and value.strip():
@@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
return default return default
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
connection_config = getattr(app_config, "channel_connections", None)
merge_runtime_channel_configs(channels_config, connection_config)
def _make_connection_repo(app_config: AppConfig):
connection_config = getattr(app_config, "channel_connections", None)
if connection_config is None or not getattr(connection_config, "enabled", False):
return None
try:
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
except Exception:
logger.exception("Failed to import channel connection repository")
return None
session_factory = get_session_factory()
if session_factory is None:
logger.warning("Channel connections are enabled but database persistence is not available")
return None
return ChannelConnectionRepository(session_factory)
class ChannelService: class ChannelService:
"""Manages the lifecycle of all configured IM channels. """Manages the lifecycle of all configured IM channels.
@@ -59,9 +90,10 @@ class ChannelService:
instantiates enabled channels, and starts the ChannelManager dispatcher. instantiates enabled channels, and starts the ChannelManager dispatcher.
""" """
def __init__(self, channels_config: dict[str, Any] | None = None) -> None: def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
self.bus = MessageBus() self.bus = MessageBus()
self.store = ChannelStore() self.store = ChannelStore()
self._connection_repo = connection_repo
config = dict(channels_config or {}) config = dict(channels_config or {})
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL) langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL) gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
@@ -74,10 +106,12 @@ class ChannelService:
gateway_url=gateway_url, gateway_url=gateway_url,
default_session=default_session if isinstance(default_session, dict) else None, default_session=default_session if isinstance(default_session, dict) else None,
channel_sessions=channel_sessions, channel_sessions=channel_sessions,
connection_repo=connection_repo,
) )
self._channels: dict[str, Any] = {} # name -> Channel instance self._channels: dict[str, Any] = {} # name -> Channel instance
self._config = config self._config = config
self._running = False self._running = False
self._readiness_locks: dict[str, asyncio.Lock] = {}
@classmethod @classmethod
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService: def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
@@ -90,8 +124,9 @@ class ChannelService:
# extra fields are allowed by AppConfig (extra="allow") # extra fields are allowed by AppConfig (extra="allow")
extra = app_config.model_extra or {} extra = app_config.model_extra or {}
if "channels" in extra: if "channels" in extra:
channels_config = extra["channels"] channels_config = dict(extra["channels"] or {})
return cls(channels_config=channels_config) _merge_channel_connection_runtime_config(channels_config, app_config)
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
async def start(self) -> None: async def start(self) -> None:
"""Start the manager and all enabled channels.""" """Start the manager and all enabled channels."""
@@ -99,36 +134,83 @@ class ChannelService:
return return
await self.manager.start() await self.manager.start()
self._running = True
ready_status = await self.ensure_ready_channels(attempts=2)
ready_count = sum(1 for ready in ready_status.values() if ready)
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
"""Start or restart enabled configured channels that are not ready."""
ready_status: dict[str, bool] = {}
for name, channel_config in self._config.items(): for name, channel_config in self._config.items():
if not isinstance(channel_config, dict): if not isinstance(channel_config, dict):
continue continue
if not channel_config.get("enabled", False): if not channel_config.get("enabled", False):
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, []) if _channel_has_credentials(name, channel_config):
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
logger.warning( logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.", "A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
name,
name,
) )
else: else:
logger.info("Channel %s is disabled, skipping", name) logger.info("A configured channel is disabled, skipping")
continue continue
await self._start_channel(name, channel_config) ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
return ready_status
self._running = True async def ensure_channel_ready(
logger.info("ChannelService started with channels: %s", list(self._channels.keys())) self,
name: str,
config: dict[str, Any] | None = None,
*,
attempts: int = 1,
) -> bool:
"""Ensure a single enabled channel is running using its current config."""
if not self._running:
logger.warning("ChannelService is not running; cannot ensure channel readiness")
return False
if config is not None:
self._config[name] = dict(config)
# Serialize per channel: readiness is polled from request handlers, so
# concurrent calls must not stop/start the same channel worker twice.
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
async with lock:
channel_config = self._config.get(name)
if not channel_config or not isinstance(channel_config, dict):
logger.warning("No config for requested channel")
return False
if not channel_config.get("enabled", False):
return False
channel = self._channels.get(name)
if channel is not None and channel.is_running:
return True
if channel is not None:
try:
await channel.stop()
except Exception:
logger.exception("Error stopping non-running channel before readiness retry")
self._channels.pop(name, None)
max_attempts = max(1, attempts)
for attempt in range(max_attempts):
if attempt > 0:
logger.info("Retrying channel startup after readiness check")
if await self._start_channel(name, channel_config):
return True
return False
async def stop(self) -> None: async def stop(self) -> None:
"""Stop all channels and the manager.""" """Stop all channels and the manager."""
for name, channel in list(self._channels.items()): for name, channel in list(self._channels.items()):
try: try:
await channel.stop() await channel.stop()
logger.info("Channel %s stopped", name) logger.info("Channel stopped")
except Exception: except Exception:
logger.exception("Error stopping channel %s", name) logger.exception("Error stopping channel")
self._channels.clear() self._channels.clear()
await self.manager.stop() await self.manager.stop()
@@ -140,6 +222,9 @@ class ChannelService:
Uses ``get_app_config()`` which detects file changes via mtime, Uses ``get_app_config()`` which detects file changes via mtime,
so edits to ``config.yaml`` are picked up without a process restart. so edits to ``config.yaml`` are picked up without a process restart.
The UI runtime-config overlay applied at startup is re-applied here
so a file-driven reload neither drops credentials entered from the
browser nor resurrects a channel disconnected from it.
Falls back to the cached ``self._config`` when config loading fails. Falls back to the cached ``self._config`` when config loading fails.
""" """
try: try:
@@ -147,7 +232,8 @@ class ChannelService:
app_config = get_app_config() app_config = get_app_config()
extra = app_config.model_extra or {} extra = app_config.model_extra or {}
channels_config = extra.get("channels", {}) channels_config = dict(extra.get("channels") or {})
_merge_channel_connection_runtime_config(channels_config, app_config)
channel_config = channels_config.get(name) channel_config = channels_config.get(name)
if isinstance(channel_config, dict): if isinstance(channel_config, dict):
# Update the cached config so get_status() stays consistent. # Update the cached config so get_status() stays consistent.
@@ -157,18 +243,23 @@ class ChannelService:
logger.exception("Failed to reload config for channel %s, using cached version", name) logger.exception("Failed to reload config for channel %s, using cached version", name)
return self._config.get(name) return self._config.get(name)
async def restart_channel(self, name: str) -> bool: async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
"""Restart a specific channel. Returns True if successful.""" """Restart a specific channel. Returns True if successful."""
if name in self._channels: if name in self._channels:
try: try:
await self._channels[name].stop() await self._channels[name].stop()
except Exception: except Exception:
logger.exception("Error stopping channel %s for restart", name) logger.exception("Error stopping channel for restart")
del self._channels[name] del self._channels[name]
config = self._load_channel_config(name) if reload_config:
# Reading config.yaml and the runtime store is disk IO; keep it
# off the event loop.
config = await asyncio.to_thread(self._load_channel_config, name)
else:
config = self._config.get(name)
if not config or not isinstance(config, dict): if not config or not isinstance(config, dict):
logger.warning("No config for channel %s", name) logger.warning("No config for requested channel")
return False return False
if not config.get("enabled", False): if not config.get("enabled", False):
@@ -177,11 +268,35 @@ class ChannelService:
return await self._start_channel(name, config) return await self._start_channel(name, config)
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Apply runtime config for a channel and restart it if the service is running."""
self._config[name] = dict(config)
if not self._running:
return True
# The caller just supplied the authoritative config (e.g. credentials
# entered in the browser that are never written to config.yaml) — a
# file reload here would clobber it with the stale on-disk entry.
return await self.restart_channel(name, reload_config=False)
async def remove_channel(self, name: str) -> bool:
"""Remove runtime config for a channel and stop it if currently running."""
self._config.pop(name, None)
channel = self._channels.pop(name, None)
if channel is None:
return True
try:
await channel.stop()
logger.info("Channel stopped and removed")
return True
except Exception:
logger.exception("Error stopping channel for removal")
return False
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool: async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Instantiate and start a single channel.""" """Instantiate and start a single channel."""
import_path = _CHANNEL_REGISTRY.get(name) import_path = _CHANNEL_REGISTRY.get(name)
if not import_path: if not import_path:
logger.warning("Unknown channel type: %s", name) logger.warning("Unknown channel type")
return False return False
try: try:
@@ -189,24 +304,26 @@ class ChannelService:
channel_cls = resolve_class(import_path, base_class=None) channel_cls = resolve_class(import_path, base_class=None)
except Exception: except Exception:
logger.exception("Failed to import channel class for %s", name) logger.exception("Failed to import channel class")
return False return False
try: try:
config = dict(config) config = dict(config)
config["channel_store"] = self.store config["channel_store"] = self.store
if self._connection_repo is not None:
config["connection_repo"] = self._connection_repo
channel = channel_cls(bus=self.bus, config=config) channel = channel_cls(bus=self.bus, config=config)
self._channels[name] = channel self._channels[name] = channel
await channel.start() await channel.start()
if not channel.is_running: if not channel.is_running:
self._channels.pop(name, None) self._channels.pop(name, None)
logger.error("Channel %s did not enter a running state after start()", name) logger.error("Channel did not enter a running state after start()")
return False return False
logger.info("Channel %s started", name) logger.info("Channel started")
return True return True
except Exception: except Exception:
self._channels.pop(name, None) self._channels.pop(name, None)
logger.exception("Failed to start channel %s", name) logger.exception("Failed to start channel")
return False return False
def get_status(self) -> dict[str, Any]: def get_status(self) -> dict[str, Any]:
@@ -245,7 +362,9 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
global _channel_service global _channel_service
if _channel_service is not None: if _channel_service is not None:
return _channel_service return _channel_service
_channel_service = ChannelService.from_app_config(app_config) # from_app_config reads the JSON channel store and runtime config files;
# keep that disk IO off the event loop.
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
await _channel_service.start() await _channel_service.start()
return _channel_service return _channel_service
+148 -26
View File
@@ -9,7 +9,8 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -64,6 +65,9 @@ class SlackChannel(Channel):
self._web_client = None self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", [])) self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._connection_repo = config.get("connection_repo")
self._web_client_factory = config.get("web_client_factory")
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
configured_bot_user_id = config.get("bot_user_id") configured_bot_user_id = config.get("bot_user_id")
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
@@ -80,26 +84,28 @@ class SlackChannel(Channel):
return return
self._SocketModeResponse = SocketModeResponse self._SocketModeResponse = SocketModeResponse
if self._web_client_factory is None:
self._web_client_factory = WebClient
bot_token = self.config.get("bot_token", "") bot_token = self.config.get("bot_token", "")
app_token = self.config.get("app_token", "") app_token = self.config.get("app_token", "")
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
if not bot_token:
logger.error("Slack HTTP Events mode requires bot_token")
return
await self._initialize_operator_web_client(str(bot_token))
self._loop = asyncio.get_event_loop()
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
logger.info("Slack channel started in HTTP Events mode")
return
if not bot_token or not app_token: if not bot_token or not app_token:
logger.error("Slack channel requires bot_token and app_token") logger.error("Slack channel requires bot_token and app_token")
return return
self._web_client = WebClient(token=bot_token) await self._initialize_operator_web_client(str(bot_token))
if self._bot_user_id is None:
try:
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
self._socket_client = SocketModeClient( self._socket_client = SocketModeClient(
app_token=app_token, app_token=app_token,
web_client=self._web_client, web_client=self._web_client,
@@ -124,7 +130,8 @@ class SlackChannel(Channel):
logger.info("Slack channel stopped") logger.info("Slack channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None: async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._web_client: web_client = await self._get_web_client_for_message(msg)
if not web_client:
return return
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
@@ -137,11 +144,12 @@ class SlackChannel(Channel):
last_exc: Exception | None = None last_exc: Exception | None = None
for attempt in range(_max_retries): for attempt in range(_max_retries):
try: try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs) await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
# Add a completion reaction to the thread root # Add a completion reaction to the thread root
if msg.thread_ts: if msg.thread_ts:
await asyncio.to_thread( await asyncio.to_thread(
self._add_reaction, self._add_reaction_with_client,
web_client,
msg.chat_id, msg.chat_id,
msg.thread_ts, msg.thread_ts,
"white_check_mark", "white_check_mark",
@@ -165,7 +173,8 @@ class SlackChannel(Channel):
if msg.thread_ts: if msg.thread_ts:
try: try:
await asyncio.to_thread( await asyncio.to_thread(
self._add_reaction, self._add_reaction_with_client,
web_client,
msg.chat_id, msg.chat_id,
msg.thread_ts, msg.thread_ts,
"x", "x",
@@ -177,7 +186,8 @@ class SlackChannel(Channel):
raise last_exc raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._web_client: web_client = await self._get_web_client_for_message(msg)
if not web_client:
return False return False
try: try:
@@ -190,7 +200,7 @@ class SlackChannel(Channel):
if msg.thread_ts: if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts kwargs["thread_ts"] = msg.thread_ts
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs) await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id) logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
return True return True
except Exception: except Exception:
@@ -199,12 +209,45 @@ class SlackChannel(Channel):
# -- internal ---------------------------------------------------------- # -- internal ----------------------------------------------------------
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None: async def _initialize_operator_web_client(self, bot_token: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking).""" self._web_client = self._web_client_factory(token=bot_token)
if not self._web_client: if self._bot_user_id is not None:
return return
try: try:
self._web_client.reactions_add( auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
async def _get_web_client_for_message(self, msg: OutboundMessage):
if msg.connection_id and self._connection_repo is not None:
credentials = await self._connection_repo.get_credentials(msg.connection_id)
access_token = credentials.get("access_token") if credentials else None
if not access_token:
return self._web_client
# WebClient keeps its own HTTP session and rate-limit state, so
# reuse one per connection until its token changes.
cached = self._connection_web_clients.get(msg.connection_id)
if cached is not None and cached[0] == access_token:
return cached[1]
if self._web_client_factory is None:
from slack_sdk import WebClient
self._web_client_factory = WebClient
web_client = self._web_client_factory(token=access_token)
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
return web_client
return self._web_client
@staticmethod
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
try:
web_client.reactions_add(
channel=channel_id, channel=channel_id,
timestamp=timestamp, timestamp=timestamp,
name=emoji, name=emoji,
@@ -213,6 +256,12 @@ class SlackChannel(Channel):
if "already_reacted" not in str(exc): if "already_reacted" not in str(exc):
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc) logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
return
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None: def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
"""Send a 'Working on it......' reply in the thread (called from SDK thread).""" """Send a 'Working on it......' reply in the thread (called from SDK thread)."""
if not self._web_client: if not self._web_client:
@@ -249,12 +298,15 @@ class SlackChannel(Channel):
# Handle message events (DM or @mention) # Handle message events (DM or @mention)
if etype in ("message", "app_mention"): if etype in ("message", "app_mention"):
self._handle_message_event(event) self._handle_message_event(
event,
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
)
except Exception: except Exception:
logger.exception("Error processing Slack event") logger.exception("Error processing Slack event")
def _handle_message_event(self, event: dict) -> None: def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
# Ignore bot messages # Ignore bot messages
if event.get("bot_id") or event.get("subtype"): if event.get("bot_id") or event.get("subtype"):
return return
@@ -272,6 +324,19 @@ class SlackChannel(Channel):
if not text: if not text:
return return
connect_code = extract_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
event=event,
team_id=str(team_id or event.get("team") or ""),
code=connect_code,
),
self._loop,
)
return
channel_id = event.get("channel", "") channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "") thread_ts = event.get("thread_ts") or event.get("ts", "")
@@ -297,4 +362,61 @@ class SlackChannel(Channel):
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes") self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
# Send "running" reply first (fire-and-forget from SDK thread) # Send "running" reply first (fire-and-forget from SDK thread)
self._send_running_reply(channel_id, thread_ts) self._send_running_reply(channel_id, thread_ts)
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop) if self._connection_repo is None:
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
else:
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="slack",
workspace_id=workspace_id,
)
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
channel_id = str(event.get("channel") or "")
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
if state is None:
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
return True
user_id = str(event.get("user") or "")
if not user_id or not team_id:
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="slack",
external_account_id=user_id,
workspace_id=team_id,
metadata={
"team_id": team_id,
"channel_id": channel_id,
},
status="connected",
)
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
return True
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
if not self._web_client or not channel_id:
return
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
if thread_ts:
kwargs["thread_ts"] = thread_ts
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
except Exception:
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
+57
View File
@@ -8,6 +8,7 @@ import threading
from typing import Any from typing import Any
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ class TelegramChannel(Channel):
pass pass
# chat_id -> last sent message_id for threaded replies # chat_id -> last sent message_id for threaded replies
self._last_bot_message: dict[str, int] = {} self._last_bot_message: dict[str, int] = {}
self._connection_repo = config.get("connection_repo")
async def start(self) -> None: async def start(self) -> None:
if self._running: if self._running:
@@ -233,6 +235,54 @@ class TelegramChannel(Channel):
return True return True
return user_id in self._allowed_users return user_id in self._allowed_users
@staticmethod
def _telegram_display_name(user) -> str:
full_name = getattr(user, "full_name", None)
if isinstance(full_name, str) and full_name:
return full_name
username = getattr(user, "username", None)
if isinstance(username, str) and username:
return username
return str(getattr(user, "id", ""))
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
if self._connection_repo is None or not state_token:
return False
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
if state is None:
await update.message.reply_text("Telegram connection link is invalid or expired.")
return True
owner_user_id = state["owner_user_id"]
user_id = str(update.effective_user.id)
chat_id = str(update.effective_chat.id)
connection = await self._connection_repo.upsert_connection(
owner_user_id=owner_user_id,
provider="telegram",
external_account_id=user_id,
external_account_name=self._telegram_display_name(update.effective_user),
workspace_id=chat_id,
workspace_name=None,
metadata={
"chat_id": chat_id,
"chat_type": update.effective_chat.type,
"telegram_username": getattr(update.effective_user, "username", None),
},
status="connected",
)
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
await update.message.reply_text("Telegram connected to DeerFlow.")
return True
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="telegram",
workspace_id=inbound.chat_id,
)
def _get_bot_username(self, context) -> str | None: def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None) bot = getattr(context, "bot", None)
username = getattr(bot, "username", None) username = getattr(bot, "username", None)
@@ -264,6 +314,11 @@ class TelegramChannel(Channel):
"""Handle /start command.""" """Handle /start command."""
if not self._check_user(update.effective_user.id): if not self._check_user(update.effective_user.id):
return return
args = getattr(context, "args", []) if context is not None else []
if args:
handled = await self._bind_connection_from_start_token(update, str(args[0]))
if handled:
return
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.") await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None: async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
@@ -299,6 +354,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id, thread_ts=msg_id,
) )
inbound.topic_id = topic_id inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running(): if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop) fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
@@ -341,6 +397,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id, thread_ts=msg_id,
) )
inbound.topic_id = topic_id inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running(): if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop) fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
+60 -2
View File
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
self._state_dir = self._resolve_state_dir(config.get("state_dir")) self._state_dir = self._resolve_state_dir(config.get("state_dir"))
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
self._connection_repo = config.get("connection_repo")
self._load_state() self._load_state()
async def start(self) -> None: async def start(self) -> None:
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
if thread_ts: if thread_ts:
self._context_tokens_by_thread[thread_ts] = context_token self._context_tokens_by_thread[thread_ts] = context_token
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
chat_id=chat_id,
context_token=context_token,
code=connect_code,
)
if handled:
return
inbound = self._make_inbound( inbound = self._make_inbound(
chat_id=chat_id, chat_id=chat_id,
user_id=chat_id, user_id=chat_id,
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
}, },
) )
inbound.topic_id = None inbound.topic_id = None
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound) await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wechat",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
if state is None:
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
return True
if not chat_id:
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wechat",
external_account_id=chat_id,
workspace_id=chat_id,
metadata={
"context_token": context_token,
},
status="connected",
)
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
return True
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
if not context_token:
return
await self._send_text_message(
chat_id=chat_id,
context_token=context_token,
text=text,
client_id_prefix="deerflow-connect",
max_retries=1,
)
async def _ensure_authenticated(self) -> bool: async def _ensure_authenticated(self) -> bool:
async with self._auth_lock: async with self._auth_lock:
if self._bot_token: if self._bot_token:
+58 -1
View File
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast from typing import Any, cast
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import ( from app.channels.message_bus import (
InboundMessage,
InboundMessageType, InboundMessageType,
MessageBus, MessageBus,
OutboundMessage, OutboundMessage,
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
self._ws_frames: dict[str, dict[str, Any]] = {} self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {} self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..." self._working_message = "Working on it..."
self._connection_repo = config.get("connection_repo")
@property @property
def supports_streaming(self) -> bool: def supports_streaming(self) -> bool:
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid") user_id = (body.get("from") or {}).get("userid")
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
frame=frame,
user_id=str(user_id or ""),
code=connect_code,
)
if handled:
return
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound( inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory chat_id=user_id, # keep user's conversation in memory
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
except Exception: except Exception:
pass pass
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound) await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wecom",
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
if state is None:
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
return True
if not user_id:
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
return True
body = frame.get("body", {}) or {}
workspace_id = str(body.get("aibotid") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wecom",
external_account_id=user_id,
workspace_id=workspace_id,
metadata={
"aibotid": workspace_id,
"chattype": body.get("chattype"),
},
status="connected",
)
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
return True
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
if not self._ws_client:
return
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None: async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._ws_client: if not self._ws_client:
return return
+4
View File
@@ -16,6 +16,7 @@ from app.gateway.routers import (
artifacts, artifacts,
assistants_compat, assistants_compat,
auth, auth,
channel_connections,
channels, channels,
feedback, feedback,
mcp, mcp,
@@ -384,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions # Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router) app.include_router(suggestions.router)
# User-facing IM channel connection API is mounted at /api/channels
app.include_router(channel_connections.router)
# Channels API is mounted at /api/channels # Channels API is mounted at /api/channels
app.include_router(channels.router) app.include_router(channels.router)
+4 -2
View File
@@ -6,9 +6,11 @@ import logging
import os import os
from types import SimpleNamespace from types import SimpleNamespace
from deerflow.runtime.user_context import DEFAULT_USER_ID
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED" AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
AUTH_DISABLED_USER_ID = "e2e-user" AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
AUTH_DISABLED_USER_EMAIL = "e2e@test.local" AUTH_DISABLED_USER_EMAIL = "default@test.local"
AUTH_SOURCE_SESSION = "session" AUTH_SOURCE_SESSION = "session"
AUTH_SOURCE_INTERNAL = "internal" AUTH_SOURCE_INTERNAL = "internal"
+18
View File
@@ -276,6 +276,8 @@ def require_permission(
# strict-deny rather than strict-allow — only an *existing* # strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404. # row with a *different* user_id triggers 404.
if owner_check: if owner_check:
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
thread_id = kwargs.get("thread_id") thread_id = kwargs.get("thread_id")
if thread_id is None: if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
@@ -288,6 +290,22 @@ def require_permission(
str(auth.user.id), str(auth.user.id),
require_existing=require_existing, require_existing=require_existing,
) )
if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
# Trusted internal callers (channel workers) also act for
# the connection owner carried in X-DeerFlow-Owner-User-Id.
# Scope the check to that owner instead of bypassing it; a
# leaked internal token must not grant cross-user thread
# access. The header is honored only after ``auth`` proved
# the caller holds the internal token (mirrors
# get_trusted_internal_owner_user_id, which keys off the
# middleware-stamped ``request.state.user``).
header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip()
if header_owner:
allowed = await thread_store.check_access(
thread_id,
header_owner,
require_existing=require_existing,
)
if not allowed: if not allowed:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
+25 -2
View File
@@ -5,10 +5,12 @@ from __future__ import annotations
import os import os
import secrets import secrets
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any
from deerflow.runtime.user_context import DEFAULT_USER_ID from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
INTERNAL_SYSTEM_ROLE = "internal" INTERNAL_SYSTEM_ROLE = "internal"
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token() _INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
def create_internal_auth_headers() -> dict[str, str]: def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
"""Return headers that authenticate trusted Gateway internal calls.""" """Return headers that authenticate trusted Gateway internal calls."""
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
if owner_user_id:
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
return headers
def is_valid_internal_auth_token(token: str | None) -> bool: def is_valid_internal_auth_token(token: str | None) -> bool:
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
def get_internal_user(): def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls.""" """Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
"""Return the owner override for a trusted internal request, if present.
The header is ignored for normal browser/API callers. It is only honored
after ``AuthMiddleware`` has validated the internal auth token and stamped
the synthetic internal user onto ``request.state.user``.
"""
user = getattr(getattr(request, "state", None), "user", None)
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
return None
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
if not owner_user_id:
return None
owner_user_id = owner_user_id.strip()
return owner_user_id or None
@@ -0,0 +1,670 @@
"""Browser-facing APIs for user-owned IM channel bindings."""
from __future__ import annotations
import asyncio
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any
from fastapi import APIRouter, HTTPException, Request, Response
from pydantic import BaseModel, Field
from app.channels.runtime_config_store import (
ChannelRuntimeConfigStore,
apply_runtime_connection_config,
merge_runtime_channel_configs,
)
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
logger = logging.getLogger(__name__)
_STATE_TTL_SECONDS = 600
_MASKED_CREDENTIAL_VALUE = "********"
class ChannelCredentialFieldResponse(BaseModel):
name: str
label: str
type: str = "text"
required: bool = True
class ChannelProviderResponse(BaseModel):
provider: str
display_name: str
enabled: bool
configured: bool
connectable: bool
unavailable_reason: str | None = None
auth_mode: str
connection_status: str
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
credential_values: dict[str, str] = Field(default_factory=dict)
class ChannelProvidersResponse(BaseModel):
enabled: bool
providers: list[ChannelProviderResponse]
class ChannelConnectionResponse(BaseModel):
id: str
provider: str
status: str
external_account_id: str | None = None
external_account_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
scopes: list[str] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
class ChannelConnectionsResponse(BaseModel):
connections: list[ChannelConnectionResponse]
class ChannelConnectResponse(BaseModel):
provider: str
mode: str
url: str | None = None
code: str
instruction: str
expires_in: int
class ChannelRuntimeConfigRequest(BaseModel):
values: dict[str, str] = Field(default_factory=dict)
_PROVIDER_META: dict[str, dict[str, str]] = {
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
}
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
"telegram": (
{"name": "bot_token", "label": "Bot token", "type": "password"},
{"name": "bot_username", "label": "Bot username", "type": "text"},
),
"slack": (
{"name": "bot_token", "label": "Bot token", "type": "password"},
{"name": "app_token", "label": "App token", "type": "password"},
),
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
"feishu": (
{"name": "app_id", "label": "App ID", "type": "text"},
{"name": "app_secret", "label": "App secret", "type": "password"},
),
"dingtalk": (
{"name": "client_id", "label": "Client ID", "type": "text"},
{"name": "client_secret", "label": "Client secret", "type": "password"},
),
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
"wecom": (
{"name": "bot_id", "label": "Bot ID", "type": "text"},
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
),
}
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
"telegram": ("bot_token",),
"slack": ("bot_token", "app_token"),
"discord": ("bot_token",),
"feishu": ("app_id", "app_secret"),
"dingtalk": ("client_id", "client_secret"),
"wechat": ("bot_token",),
"wecom": ("bot_id", "bot_secret"),
}
def _get_user_id(request: Request) -> str:
user = getattr(request.state, "user", None)
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
return str(user.id)
async def _require_admin_user(request: Request) -> None:
"""Require an admin caller for instance-wide channel runtime mutations.
Runtime credentials and the channel workers they start/stop are shared by
every user of the deployment, so only admins may change them (same model
as the MCP config API). Auth-disabled local mode uses a synthetic admin
user and is unaffected.
"""
user = getattr(request.state, "user", None)
if user is None:
from app.gateway.deps import get_current_user_from_request
user = await get_current_user_from_request(request)
if getattr(user, "system_role", None) != "admin":
raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.")
def _get_app_config():
from deerflow.config.app_config import get_app_config
return get_app_config()
async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
store = getattr(request.app.state, "channel_runtime_config_store", None)
if isinstance(store, ChannelRuntimeConfigStore):
return store
# Constructing the store reads its JSON file from disk; keep it off the
# event loop.
store = await asyncio.to_thread(ChannelRuntimeConfigStore)
request.app.state.channel_runtime_config_store = store
return store
async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
config = getattr(request.app.state, "channel_connections_config", None)
if not isinstance(config, ChannelConnectionsConfig):
config = _get_app_config().channel_connections
config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request))
request.app.state.channel_connections_config = config
return config
async def _get_channels_config(request: Request) -> dict[str, Any]:
state_config = getattr(request.app.state, "channels_config", None)
if isinstance(state_config, dict):
return state_config
result = await _load_channels_config(request, await _get_channel_connections_config(request))
request.app.state.channels_config = result
return result
async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
app_config = _get_app_config()
extra = app_config.model_extra or {}
channels_config = extra.get("channels")
result = dict(channels_config) if isinstance(channels_config, dict) else {}
merge_runtime_channel_configs(
result,
config,
store=await _get_runtime_config_store(request),
)
return result
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
repo = getattr(request.app.state, "channel_connection_repo", None)
if isinstance(repo, ChannelConnectionRepository):
return repo
sf = get_session_factory()
if sf is None:
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
repo = ChannelConnectionRepository(sf)
request.app.state.channel_connection_repo = repo
return repo
def _provider_config(config: ChannelConnectionsConfig, provider: str):
provider_config = getattr(config, provider, None)
if provider_config is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return provider_config
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
return False
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
def _runtime_unavailable_reason(provider: str) -> str:
meta = _PROVIDER_META.get(provider)
display_name = meta["display_name"] if meta else provider
return f"Enter the required {display_name} credentials to connect this channel."
def _runtime_not_running_reason(provider: str) -> str:
meta = _PROVIDER_META.get(provider)
display_name = meta["display_name"] if meta else provider
return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
def _runtime_channel_running(provider: str) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.debug("Unable to inspect channel service status", exc_info=True)
return None
service = get_channel_service()
if service is None:
return None
try:
status = service.get_status()
except Exception:
logger.debug("Unable to read channel service status", exc_info=True)
return None
if not status.get("service_running"):
return False
channel_status = status.get("channels", {}).get(provider)
if not isinstance(channel_status, dict):
return None
return bool(channel_status.get("running"))
async def _ensure_runtime_channel_ready_if_available(
provider: str,
channels_config: dict[str, Any],
) -> bool | None:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
return None
try:
from app.channels.service import get_channel_service
except Exception:
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
return None
service = get_channel_service()
if service is None:
return None
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
if ensure_channel_ready is None:
return None
try:
return await ensure_channel_ready(provider, runtime_config)
except Exception:
logger.exception("Failed to reconcile runtime channel readiness")
return False
def _provider_unavailable_reason(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
) -> str | None:
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
return None
if not provider_config.configured:
return _runtime_unavailable_reason(provider)
if not _runtime_channel_configured(provider, channels_config):
return _runtime_unavailable_reason(provider)
if _runtime_channel_running(provider) is False:
return _runtime_not_running_reason(provider)
return None
def _provider_status(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
) -> tuple[dict[str, bool], str | None]:
declared = config.provider_status(provider)
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
def _new_binding_code() -> str:
return secrets.token_urlsafe(16)
async def _create_state(
repo: ChannelConnectionRepository,
*,
owner_user_id: str,
provider: str,
) -> str:
state = _new_binding_code()
await repo.create_oauth_state(
owner_user_id=owner_user_id,
provider=provider,
state=state,
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
)
return state
def _connect_instruction(provider: str, code: str) -> str:
if provider == "telegram":
return f"Send /start {code} to the DeerFlow Telegram bot."
meta = _PROVIDER_META.get(provider)
if meta is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
if provider == "telegram":
provider_config = _provider_config(config, provider)
return f"https://t.me/{provider_config.bot_username}?start={code}"
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
return None
raise HTTPException(status_code=404, detail="Unknown channel provider")
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
value = connection.get("updated_at")
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
if isinstance(value, str) and value:
try:
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
pass
return datetime.min.replace(tzinfo=UTC)
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
by_provider: dict[str, dict[str, Any]] = {}
for item in connections:
existing = by_provider.get(item["provider"])
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
by_provider[item["provider"]] = item
return by_provider
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
fields = _CREDENTIAL_FIELDS.get(provider)
if fields is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return [ChannelCredentialFieldResponse(**field) for field in fields]
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict):
return {}
values: dict[str, str] = {}
for field in _credential_fields(provider):
value = str(runtime_config.get(field.name) or "").strip()
if not value:
continue
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
return values
def _provider_response(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
meta: dict[str, str],
connection: dict[str, Any] | None = None,
) -> ChannelProviderResponse:
from app.gateway.auth_disabled import is_auth_disabled
status, unavailable_reason = _provider_status(config, channels_config, provider)
if connection:
connection_status = connection["status"]
elif is_auth_disabled() and status["configured"] and unavailable_reason is None:
# Auth-disabled local mode routes every channel message to the default
# user, so a configured running channel needs no per-user binding.
connection_status = "connected"
else:
connection_status = "not_connected"
credential_values = _credential_values(provider, channels_config)
if provider == "telegram" and not credential_values.get("bot_username"):
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
if bot_username:
credential_values["bot_username"] = bot_username
return ChannelProviderResponse(
provider=provider,
display_name=meta["display_name"],
enabled=status["enabled"],
configured=status["configured"],
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
unavailable_reason=unavailable_reason,
auth_mode=meta["auth_mode"],
connection_status=connection_status,
credential_fields=_credential_fields(provider),
credential_values=credential_values,
)
def _required_runtime_values(
provider: str,
values: dict[str, str],
existing_config: dict[str, Any] | None = None,
) -> dict[str, str]:
fields = _credential_fields(provider)
cleaned: dict[str, str] = {}
missing: list[str] = []
existing_config = existing_config or {}
for field in fields:
raw_value = values.get(field.name, "")
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
existing_value = str(existing_config.get(field.name) or "").strip()
if existing_value:
cleaned[field.name] = existing_value
continue
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
if field.required and not value:
missing.append(field.label)
cleaned[field.name] = value
if missing:
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
return cleaned
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while configuring a runtime channel")
return None
service = get_channel_service()
if service is None:
return None
return await service.configure_channel(provider, runtime_config)
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while disconnecting a runtime channel")
return None
service = get_channel_service()
if service is None:
return None
runtime_config = channels_config.get(provider)
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
return await service.configure_channel(provider, runtime_config)
return await service.remove_channel(provider)
@router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
config = await _get_channel_connections_config(request)
channels_config = await _get_channels_config(request)
repo = None
if config.enabled:
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
owner_user_id = _get_user_id(request)
connections = await repo.list_connections(owner_user_id) if repo is not None else []
by_provider = _newest_connection_by_provider(connections)
enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]]
# Readiness reconciliation is independent per provider; run it
# concurrently so one slow channel restart does not serialize the
# whole /providers response.
await asyncio.gather(
*(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)),
)
providers: list[ChannelProviderResponse] = []
for provider in enabled_providers:
connection = by_provider.get(provider)
providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection))
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
@router.get("/connections", response_model=ChannelConnectionsResponse)
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
config = await _get_channel_connections_config(request)
if not config.enabled:
return ChannelConnectionsResponse(connections=[])
repo = _get_repository(request, config)
rows = await repo.list_connections(_get_user_id(request))
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
@router.delete("/connections/{connection_id}", status_code=204)
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
config = await _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
repo = _get_repository(request, config)
disconnected = await repo.disconnect_connection(
connection_id=connection_id,
owner_user_id=_get_user_id(request),
)
if not disconnected:
raise HTTPException(status_code=404, detail="Channel connection not found")
return Response(status_code=204)
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
await _require_admin_user(request)
config = await _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
owner_user_id = _get_user_id(request)
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
repo = None
if repo is not None:
for connection in await repo.list_connections(owner_user_id):
if connection["provider"] == provider and connection["status"] != "revoked":
await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id=owner_user_id,
)
store = await _get_runtime_config_store(request)
await asyncio.to_thread(store.set_provider_disconnected, provider)
channels_config = await _load_channels_config(request, config)
request.app.state.channels_config = channels_config
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
if stopped is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
config = await _get_channel_connections_config(request)
channels_config = await _get_channels_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
status, unavailable_reason = _provider_status(config, channels_config, provider)
if not status["enabled"]:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
if unavailable_reason:
raise HTTPException(status_code=400, detail=unavailable_reason)
if not status["configured"]:
raise HTTPException(status_code=400, detail="Channel provider is not configured")
repo = _get_repository(request, config)
code = await _create_state(
repo,
owner_user_id=_get_user_id(request),
provider=provider,
)
return ChannelConnectResponse(
provider=provider,
mode=_PROVIDER_META[provider]["auth_mode"],
url=_connect_url(config, provider, code),
code=code,
instruction=_connect_instruction(provider, code),
expires_in=_STATE_TTL_SECONDS,
)
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def configure_channel_provider_runtime(
provider: str,
body: ChannelRuntimeConfigRequest,
request: Request,
) -> ChannelProviderResponse:
await _require_admin_user(request)
config = await _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
channels_config = await _get_channels_config(request)
existing = channels_config.get(provider)
runtime_config = dict(existing) if isinstance(existing, dict) else {}
values = _required_runtime_values(provider, body.values, runtime_config)
runtime_config["enabled"] = True
for key in _RUNTIME_REQUIREMENTS[provider]:
runtime_config[key] = values[key]
if provider == "telegram":
# The deep-link username is persisted with the runtime channel config
# (set_provider_config below) and applied to future requests via
# apply_runtime_connection_config; never mutate the config instance
# cached by get_app_config().
runtime_config["bot_username"] = values["bot_username"]
channels_config[provider] = runtime_config
request.app.state.channels_config = channels_config
started = await _restart_runtime_channel_if_available(provider, runtime_config)
if started is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
store = await _get_runtime_config_store(request)
await asyncio.to_thread(store.set_provider_config, provider, runtime_config)
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
+11 -1
View File
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer from app.gateway.deps import get_checkpointer
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
thread_store = get_thread_store(request) thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4()) thread_id = body.thread_id or str(uuid.uuid4())
now = now_iso() now = now_iso()
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
# ``body.metadata`` is already stripped of server-reserved keys by # ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition. # ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present # Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id) existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is None and thread_owner_user_id:
unscoped_record = await thread_store.get(thread_id, user_id=None)
if unscoped_record is not None:
if unscoped_record.get("user_id") != thread_owner_user_id:
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is not None: if existing_record is not None:
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
await thread_store.create( await thread_store.create(
thread_id, thread_id,
assistant_id=getattr(body, "assistant_id", None), assistant_id=getattr(body, "assistant_id", None),
**thread_owner_kwargs,
metadata=body.metadata, metadata=body.metadata,
) )
except Exception: except Exception:
+84 -61
View File
@@ -12,6 +12,7 @@ import json
import logging import logging
import re import re
from collections.abc import Mapping from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.runtime import ( from deerflow.runtime import (
@@ -35,6 +36,7 @@ from deerflow.runtime import (
run_agent, run_agent,
) )
from deerflow.runtime.runs.naming import resolve_root_run_name from deerflow.runtime.runs.naming import resolve_root_run_name
from deerflow.runtime.user_context import reset_current_user, set_current_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -315,6 +317,7 @@ async def start_run(
detail=f"Model {model_name!r} is not in the configured model allowlist", detail=f"Model {model_name!r} is not in the configured model allowlist",
) )
owner_user_id = get_trusted_internal_owner_user_id(request)
# Stateless run endpoints carry thread_id in the request *body*, so the # Stateless run endpoints carry thread_id in the request *body*, so the
# @require_permission(owner_check=True) decorator -- which resolves ownership # @require_permission(owner_check=True) decorator -- which resolves ownership
# from the path param -- cannot protect them. Enforce thread ownership here, # from the path param -- cannot protect them. Enforce thread ownership here,
@@ -323,79 +326,99 @@ async def start_run(
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible # temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
# via check_access; only a thread already owned by another user is rejected # via check_access; only a thread already owned by another user is rejected
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal # with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
# channel runs act on behalf of IM users they do not own (see # channel runs act on behalf of the connection owner carried in
# inject_authenticated_user_context), so the internal system role is exempt. # X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
# bypassing the check -- a leaked internal token must not grant cross-user
# thread access.
user = getattr(request.state, "user", None) user = getattr(request.state, "user", None)
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE: if user is not None:
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)): allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
# Channel workers may also act for the connection owner named in
# the trusted header (e.g. claiming a legacy default-owned channel
# thread for its real owner).
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
if not allowed:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
try: try:
record = await run_mgr.create_or_reject( try:
thread_id, record = await run_mgr.create_or_reject(
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id, thread_id,
assistant_id=body.assistant_id, body.assistant_id,
metadata=body.metadata, on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
user_id=owner_user_id,
) )
else: except ConflictError as exc:
await run_ctx.thread_store.update_status(thread_id, "running") raise HTTPException(status_code=409, detail=str(exc)) from exc
except Exception: except UnsupportedStrategyError as exc:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) raise HTTPException(status_code=501, detail=str(exc)) from exc
agent_factory = resolve_agent_factory(body.assistant_id) # Upsert thread metadata so the thread appears in /threads/search,
graph_input = normalize_input(body.input) # even for threads that were never explicitly created via POST /threads
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) # (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None and owner_user_id:
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
if unscoped_existing is not None:
if unscoped_existing.get("user_id") != owner_user_id:
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``. agent_factory = resolve_agent_factory(body.assistant_id)
# The ``context`` field is a custom extension for the langgraph-compat layer graph_input = normalize_input(body.input)
# that carries agent configuration (model_name, thinking_enabled, etc.). config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
stream_modes = normalize_stream_modes(body.stream_mode) # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
task = asyncio.create_task( stream_modes = normalize_stream_modes(body.stream_mode)
run_agent(
bridge, task = asyncio.create_task(
run_mgr, run_agent(
record, bridge,
ctx=run_ctx, run_mgr,
agent_factory=agent_factory, record,
graph_input=graph_input, ctx=run_ctx,
config=config, agent_factory=agent_factory,
stream_modes=stream_modes, graph_input=graph_input,
stream_subgraphs=body.stream_subgraphs, config=config,
interrupt_before=body.interrupt_before, stream_modes=stream_modes,
interrupt_after=body.interrupt_after, stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
)
) )
) record.task = task
record.task = task
# Title sync is handled by worker.py's finally block which reads the # Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_store.update_display_name # title from the checkpoint and calls thread_store.update_display_name
# after the run completes. # after the run completes.
return record return record
finally:
if owner_context_token is not None:
reset_current_user(owner_context_token)
async def sse_consumer( async def sse_consumer(
+122
View File
@@ -0,0 +1,122 @@
# IM Channel Connections
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
## Configuration
Configure the actual IM bots under the existing `channels` block:
```yaml
channels:
telegram:
enabled: true
bot_token: $TELEGRAM_BOT_TOKEN
slack:
enabled: true
bot_token: $SLACK_BOT_TOKEN
app_token: $SLACK_APP_TOKEN
discord:
enabled: true
bot_token: $DISCORD_BOT_TOKEN
feishu:
enabled: true
app_id: $FEISHU_APP_ID
app_secret: $FEISHU_APP_SECRET
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID
client_secret: $DINGTALK_CLIENT_SECRET
wechat:
enabled: true
bot_token: $WECHAT_BOT_TOKEN
wecom:
enabled: true
bot_id: $WECOM_BOT_ID
bot_secret: $WECOM_BOT_SECRET
```
Then enable user bindings in `channel_connections`:
```yaml
channel_connections:
enabled: true
telegram:
enabled: true
bot_username: $TELEGRAM_BOT_USERNAME
slack:
enabled: true
discord:
enabled: true
feishu:
enabled: true
dingtalk:
enabled: true
wechat:
enabled: true
wecom:
enabled: true
```
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
## Connect Flow
Telegram:
- The frontend creates a short one-time code.
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
Slack:
- The frontend creates a short one-time code.
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
Discord:
- The frontend creates a short one-time code.
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
Feishu/Lark, DingTalk, WeChat, and WeCom:
- The frontend creates a short one-time code.
- The UI shows `Send /connect <code> to the DeerFlow <Provider> bot.`
- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user.
Codes use 128 bits of randomness, expire after 10 minutes, and are single-use.
## Runtime Model
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
## Security Notes
- Browser APIs remain authenticated and CSRF-protected.
- Connect codes are 128-bit random, short-lived, and single-use.
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets.
- This implementation does not add public provider callback or webhook routes.
@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
@@ -116,6 +117,13 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
channel_connections: ChannelConnectionsConfig = Field(
default_factory=ChannelConnectionsConfig,
description=format_field_description(
"channel_connections",
field_doc="User-facing IM channel connection configuration.",
),
)
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
@@ -0,0 +1,61 @@
"""Configuration for user-owned IM channel connections."""
from __future__ import annotations
from pydantic import BaseModel, Field
class SlackChannelConnectionConfig(BaseModel):
enabled: bool = False
@property
def configured(self) -> bool:
return True
class TelegramChannelConnectionConfig(BaseModel):
enabled: bool = False
bot_username: str = ""
@property
def configured(self) -> bool:
return bool(self.bot_username)
class DiscordChannelConnectionConfig(BaseModel):
enabled: bool = False
@property
def configured(self) -> bool:
return True
class BindingCodeChannelConnectionConfig(BaseModel):
enabled: bool = False
@property
def configured(self) -> bool:
return True
class ChannelConnectionsConfig(BaseModel):
"""Top-level config for browser-connectable IM channels."""
enabled: bool = False
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
feishu: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
dingtalk: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
wechat: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
wecom: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
def provider_status(self, provider: str) -> dict[str, bool]:
config = getattr(self, provider, None)
if config is None:
return {"enabled": False, "configured": False}
enabled = bool(config.enabled)
return {
"enabled": enabled,
"configured": enabled and bool(config.configured),
}
@@ -1,4 +1,5 @@
import hashlib import hashlib
import logging
import os import os
import re import re
import shutil import shutil
@@ -14,6 +15,8 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]") _UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
_SAFE_USER_ID_DIGEST_HEX_LEN = 16 _SAFE_USER_ID_DIGEST_HEX_LEN = 16
logger = logging.getLogger(__name__)
def _default_local_base_dir() -> Path: def _default_local_base_dir() -> Path:
"""Return the caller project's writable DeerFlow state directory.""" """Return the caller project's writable DeerFlow state directory."""
@@ -47,7 +50,13 @@ def make_safe_user_id(raw: str) -> str:
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw) sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
if sanitized == raw: if sanitized == raw:
return raw return raw
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN] digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
return f"{sanitized}-{digest}"
def _legacy_safe_user_id(raw: str, sanitized: str) -> str:
"""Bucket name produced by the previous (SHA-1) digest revision for ``raw``."""
digest = hashlib.sha1(raw.encode("utf-8"), usedforsecurity=False).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
return f"{sanitized}-{digest}" return f"{sanitized}-{digest}"
@@ -172,6 +181,32 @@ class Paths:
"""Directory for a specific user: `{base_dir}/users/{user_id}/`.""" """Directory for a specific user: `{base_dir}/users/{user_id}/`."""
return self.base_dir / "users" / _validate_user_id(user_id) return self.base_dir / "users" / _validate_user_id(user_id)
def prepare_user_dir_for_raw_id(self, raw_user_id: str) -> str:
"""Return the safe user ID and migrate this ID's legacy unsafe-id bucket.
A previous branch revision used SHA-1 for unsafe external user IDs.
New IDs use SHA-256; the legacy bucket name is recomputed from the same
raw ID, so only this user's own old bucket can ever be moved — a
different raw ID sharing the sanitized prefix produces a different
legacy digest and is never touched.
"""
safe_user_id = make_safe_user_id(raw_user_id)
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw_user_id)
if safe_user_id == raw_user_id:
return safe_user_id
users_dir = self.base_dir / "users"
target_dir = users_dir / safe_user_id
legacy_dir = users_dir / _legacy_safe_user_id(raw_user_id, sanitized)
try:
if target_dir.exists() or not legacy_dir.is_dir():
return safe_user_id
legacy_dir.rename(target_dir)
logger.info("Migrated legacy unsafe-id user directory to the current digest format")
except OSError:
logger.exception("Failed to migrate legacy unsafe-id user directory")
return safe_user_id
def user_memory_file(self, user_id: str) -> Path: def user_memory_file(self, user_id: str) -> Path:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`.""" """Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json" return self.user_dir(user_id) / "memory.json"
@@ -56,6 +56,9 @@ STARTUP_ONLY_FIELDS: dict[str, str] = {
# startup and the live channel clients are not rebuilt on # startup and the live channel clients are not rebuilt on
# config.yaml edits. # config.yaml edits.
"channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."), "channels": ("start_channel_service() is invoked once during startup; the live IM channel clients (Feishu, Slack, Telegram, DingTalk) are not rebuilt when channels.* changes."),
"channel_connections": (
"start_channel_service() wires the connection repository and channel workers once at startup, and the channel-connections router caches the merged provider config on app.state; channel_connections.* edits need a restart."
),
} }
@@ -0,0 +1,21 @@
"""User-owned IM channel connection persistence."""
from deerflow.persistence.channel_connections.model import (
ChannelConnectionRow,
ChannelConversationRow,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
from deerflow.persistence.channel_connections.sql import (
ChannelConnectionRepository,
ChannelCredentialCipher,
)
__all__ = [
"ChannelConnectionRepository",
"ChannelConnectionRow",
"ChannelConversationRow",
"ChannelCredentialCipher",
"ChannelCredentialRow",
"ChannelOAuthStateRow",
]
@@ -0,0 +1,111 @@
"""ORM models for user-owned IM channel connections."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
def _utc_now() -> datetime:
return datetime.now(UTC)
class ChannelConnectionRow(Base):
__tablename__ = "channel_connections"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected")
external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
scopes_json: Mapped[list] = mapped_column(JSON, default=list)
capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict)
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
__table_args__ = (
UniqueConstraint(
"owner_user_id",
"provider",
"external_account_id",
"workspace_id",
name="uq_channel_connection_owner_provider_identity",
),
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
)
class ChannelCredentialRow(Base):
__tablename__ = "channel_credentials"
connection_id: Mapped[str] = mapped_column(
String(64),
ForeignKey("channel_connections.id", ondelete="CASCADE"),
primary_key=True,
)
encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True)
encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True)
token_type: Mapped[str | None] = mapped_column(String(32), nullable=True)
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True)
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
class ChannelOAuthStateRow(Base):
__tablename__ = "channel_oauth_states"
state_hash: Mapped[str] = mapped_column(String(128), primary_key=True)
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True)
requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list)
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
class ChannelConversationRow(Base):
__tablename__ = "channel_conversations"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
connection_id: Mapped[str] = mapped_column(
String(64),
ForeignKey("channel_connections.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False)
external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
__table_args__ = (
UniqueConstraint(
"connection_id",
"external_conversation_id",
"external_topic_id",
name="uq_channel_conversation_connection_external",
),
)
@@ -0,0 +1,387 @@
"""SQL repository for user-owned IM channel connections."""
from __future__ import annotations
import base64
import hashlib
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import delete, func, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.channel_connections.model import (
ChannelConnectionRow,
ChannelConversationRow,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
from deerflow.utils.time import coerce_iso
logger = logging.getLogger(__name__)
class ChannelCredentialCipher:
"""Encrypts provider credentials before they are persisted."""
def __init__(self, fernet: Fernet) -> None:
self._fernet = fernet
@classmethod
def from_key(cls, key: str) -> ChannelCredentialCipher:
digest = hashlib.sha256(key.encode("utf-8")).digest()
return cls(Fernet(base64.urlsafe_b64encode(digest)))
def encrypt_text(self, value: str | None) -> str | None:
if value is None:
return None
return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")
def decrypt_text(self, value: str | None) -> str | None:
if value is None:
return None
token = value.removeprefix("fernet:v1:")
return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")
class ChannelConnectionRepository:
"""Persistence facade for channel connections, credentials, and conversations."""
def __init__(
self,
session_factory: async_sessionmaker[AsyncSession],
*,
cipher: ChannelCredentialCipher | None = None,
) -> None:
self.session_factory = session_factory
self._cipher = cipher
async def close(self) -> None:
from deerflow.persistence.engine import close_engine
await close_engine()
@staticmethod
def _new_id() -> str:
return uuid.uuid4().hex
@staticmethod
def _normalize_optional_identity(value: str | None) -> str:
return value or ""
@staticmethod
def _coerce_datetime(value: datetime | None) -> datetime | None:
if value is None or value.tzinfo is not None:
return value
return value.replace(tzinfo=UTC)
def _encrypt_optional_secret(self, value: str | None) -> str | None:
if value is None:
return None
if self._cipher is None:
raise RuntimeError("channel connection encryption key is required")
return self._cipher.encrypt_text(value)
@staticmethod
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
data = row.to_dict()
data["external_account_id"] = data["external_account_id"] or None
data["workspace_id"] = data["workspace_id"] or None
data["scopes"] = data.pop("scopes_json") or []
data["capabilities"] = data.pop("capabilities_json") or {}
data["metadata"] = data.pop("metadata_json") or {}
for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
value = data.get(key)
if isinstance(value, datetime):
data[key] = coerce_iso(value)
return data
async def upsert_connection(
self,
*,
owner_user_id: str,
provider: str,
external_account_id: str | None = None,
external_account_name: str | None = None,
workspace_id: str | None = None,
workspace_name: str | None = None,
bot_user_id: str | None = None,
scopes: list[str] | None = None,
capabilities: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
status: str = "connected",
) -> dict[str, Any]:
external_account_id_value = self._normalize_optional_identity(external_account_id)
workspace_id_value = self._normalize_optional_identity(workspace_id)
def _apply(row: ChannelConnectionRow) -> None:
row.status = status
row.external_account_name = external_account_name
row.workspace_name = workspace_name
row.bot_user_id = bot_user_id
row.scopes_json = list(scopes or [])
row.capabilities_json = dict(capabilities or {})
row.metadata_json = dict(metadata or {})
stmt = select(ChannelConnectionRow).where(
ChannelConnectionRow.owner_user_id == owner_user_id,
ChannelConnectionRow.provider == provider,
ChannelConnectionRow.external_account_id == external_account_id_value,
ChannelConnectionRow.workspace_id == workspace_id_value,
)
async with self.session_factory() as session:
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
row = ChannelConnectionRow(
id=self._new_id(),
owner_user_id=owner_user_id,
provider=provider,
external_account_id=external_account_id_value,
workspace_id=workspace_id_value,
)
session.add(row)
_apply(row)
try:
await session.commit()
except IntegrityError:
# A concurrent writer inserted the same identity first; retry as
# an update of that row.
await session.rollback()
row = (await session.execute(stmt)).scalar_one()
_apply(row)
await session.commit()
await session.refresh(row)
return self._connection_to_dict(row)
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
async with self.session_factory() as session:
result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
return [self._connection_to_dict(row) for row in result.scalars()]
async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
async with self.session_factory() as session:
row = await session.get(ChannelConnectionRow, connection_id)
if row is None or row.owner_user_id != owner_user_id:
return False
row.status = "revoked"
credential = await session.get(ChannelCredentialRow, connection_id)
if credential is not None:
await session.delete(credential)
await session.commit()
return True
async def store_credentials(
self,
connection_id: str,
*,
access_token: str | None,
refresh_token: str | None = None,
token_type: str | None = None,
expires_at: datetime | None = None,
refresh_expires_at: datetime | None = None,
extra: dict[str, Any] | None = None,
) -> None:
if self._cipher is None:
raise RuntimeError("channel connection encryption key is required")
async with self.session_factory() as session:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
row = ChannelCredentialRow(connection_id=connection_id)
session.add(row)
row.encrypted_access_token = self._cipher.encrypt_text(access_token)
row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
row.token_type = token_type
row.expires_at = expires_at
row.refresh_expires_at = refresh_expires_at
row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
row.version = (row.version or 0) + 1
await session.commit()
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
if self._cipher is None:
return None
async with self.session_factory() as session:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
return None
try:
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
return {
"connection_id": row.connection_id,
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
"token_type": row.token_type,
"expires_at": self._coerce_datetime(row.expires_at),
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
"extra": json.loads(extra_raw) if extra_raw else {},
}
except (InvalidToken, UnicodeError, json.JSONDecodeError):
logger.warning(
"Unable to decrypt channel connection credentials; treating credentials as unavailable",
exc_info=True,
)
return None
@staticmethod
def hash_state(state: str) -> str:
return hashlib.sha256(state.encode("utf-8")).hexdigest()
async def create_oauth_state(
self,
*,
owner_user_id: str,
provider: str,
state: str,
expires_at: datetime,
code_verifier: str | None = None,
nonce_hash: str | None = None,
redirect_after: str | None = None,
requested_scopes: list[str] | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
row = ChannelOAuthStateRow(
state_hash=self.hash_state(state),
owner_user_id=owner_user_id,
provider=provider,
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
nonce_hash=nonce_hash,
redirect_after=redirect_after,
requested_scopes_json=list(requested_scopes or []),
metadata_json=dict(metadata or {}),
expires_at=expires_at,
)
async with self.session_factory() as session:
session.add(row)
await session.commit()
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
async with self.session_factory() as session:
result = await session.execute(
select(func.count())
.select_from(ChannelOAuthStateRow)
.where(
ChannelOAuthStateRow.owner_user_id == owner_user_id,
ChannelOAuthStateRow.provider == provider,
)
)
return int(result.scalar_one())
async def consume_oauth_state(
self,
*,
provider: str,
state: str,
now: datetime | None = None,
) -> dict[str, Any] | None:
current_time = now or datetime.now(UTC)
state_hash = self.hash_state(state)
async with self.session_factory() as session:
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
row = await session.get(ChannelOAuthStateRow, state_hash)
if row is None or row.provider != provider or row.consumed_at is not None:
await session.commit()
return None
expires_at = self._coerce_datetime(row.expires_at)
if expires_at is not None and expires_at < current_time:
await session.commit()
return None
# Conditional UPDATE so two concurrent workers cannot both consume
# the same binding code: only the writer that flips consumed_at
# from NULL wins.
result = await session.execute(
update(ChannelOAuthStateRow)
.where(
ChannelOAuthStateRow.state_hash == state_hash,
ChannelOAuthStateRow.consumed_at.is_(None),
)
.values(consumed_at=current_time)
)
await session.commit()
if result.rowcount != 1:
return None
return {
"owner_user_id": row.owner_user_id,
"provider": row.provider,
"requested_scopes": row.requested_scopes_json or [],
"metadata": row.metadata_json or {},
"redirect_after": row.redirect_after,
}
async def find_connection_by_external_identity(
self,
*,
provider: str,
external_account_id: str,
workspace_id: str | None = None,
) -> dict[str, Any] | None:
async with self.session_factory() as session:
result = await session.execute(
select(ChannelConnectionRow)
.where(
ChannelConnectionRow.provider == provider,
ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
ChannelConnectionRow.status == "connected",
)
.order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
.limit(1)
)
row = result.scalar_one_or_none()
return self._connection_to_dict(row) if row is not None else None
async def set_thread_id(
self,
*,
connection_id: str,
owner_user_id: str,
provider: str,
external_conversation_id: str,
thread_id: str,
external_topic_id: str | None = None,
) -> None:
topic_id = external_topic_id or ""
async with self.session_factory() as session:
stmt = select(ChannelConversationRow).where(
ChannelConversationRow.connection_id == connection_id,
ChannelConversationRow.external_conversation_id == external_conversation_id,
ChannelConversationRow.external_topic_id == topic_id,
)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
row = ChannelConversationRow(
id=self._new_id(),
connection_id=connection_id,
owner_user_id=owner_user_id,
provider=provider,
external_conversation_id=external_conversation_id,
external_topic_id=topic_id,
thread_id=thread_id,
)
session.add(row)
else:
row.thread_id = thread_id
row.owner_user_id = owner_user_id
row.provider = provider
await session.commit()
async def get_thread_id(
self,
connection_id: str,
external_conversation_id: str,
external_topic_id: str | None = None,
) -> str | None:
async with self.session_factory() as session:
stmt = select(ChannelConversationRow.thread_id).where(
ChannelConversationRow.connection_id == connection_id,
ChannelConversationRow.external_conversation_id == external_conversation_id,
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
)
return (await session.execute(stmt)).scalar_one_or_none()
@@ -14,10 +14,26 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and
there is no matching entity directory. there is no matching entity directory.
""" """
from deerflow.persistence.channel_connections.model import (
ChannelConnectionRow,
ChannelConversationRow,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
from deerflow.persistence.feedback.model import FeedbackRow from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.models.run_event import RunEventRow from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.run.model import RunRow from deerflow.persistence.run.model import RunRow
from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.user.model import UserRow from deerflow.persistence.user.model import UserRow
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"] __all__ = [
"ChannelConnectionRow",
"ChannelConversationRow",
"ChannelCredentialRow",
"ChannelOAuthStateRow",
"FeedbackRow",
"RunEventRow",
"RunRow",
"ThreadMetaRow",
"UserRow",
]
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
""" """
pass pass
@abc.abstractmethod
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
"""Move a thread metadata row to a new owner.
Intended for trusted internal repair/migration paths. No-op if the
row does not exist or the caller fails the owner check.
"""
pass
@abc.abstractmethod @abc.abstractmethod
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``user_id`` has access to ``thread_id``.""" """Check if ``user_id`` has access to ``thread_id``."""
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
record["updated_at"] = now_iso() record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record) await self._store.aput(THREADS_NS, thread_id, record)
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
if record is None:
return
record["user_id"] = owner_user_id
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete") record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
if record is None: if record is None:
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
row.updated_at = datetime.now(UTC) row.updated_at = datetime.now(UTC)
await session.commit() await session.commit()
async def update_owner(
self,
thread_id: str,
owner_user_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Move a thread metadata row to ``owner_user_id``."""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
await session.commit()
async def delete( async def delete(
self, self,
thread_id: str, thread_id: str,
@@ -83,6 +83,7 @@ class RunRecord:
multitask_strategy: str = "reject" multitask_strategy: str = "reject"
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)
user_id: str | None = None
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
task: asyncio.Task | None = field(default=None, repr=False) task: asyncio.Task | None = field(default=None, repr=False)
@@ -124,7 +125,7 @@ class RunManager:
@staticmethod @staticmethod
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
return { payload = {
"thread_id": record.thread_id, "thread_id": record.thread_id,
"assistant_id": record.assistant_id, "assistant_id": record.assistant_id,
"status": record.status.value, "status": record.status.value,
@@ -135,6 +136,9 @@ class RunManager:
"created_at": record.created_at, "created_at": record.created_at,
"model_name": record.model_name, "model_name": record.model_name,
} }
if record.user_id is not None:
payload["user_id"] = record.user_id
return payload
async def _call_store_with_retry( async def _call_store_with_retry(
self, self,
@@ -241,6 +245,7 @@ class RunManager:
kwargs=row.get("kwargs") or {}, kwargs=row.get("kwargs") or {},
created_at=row.get("created_at") or "", created_at=row.get("created_at") or "",
updated_at=row.get("updated_at") or "", updated_at=row.get("updated_at") or "",
user_id=row.get("user_id"),
error=row.get("error"), error=row.get("error"),
model_name=row.get("model_name"), model_name=row.get("model_name"),
store_only=True, store_only=True,
@@ -320,6 +325,7 @@ class RunManager:
metadata: dict | None = None, metadata: dict | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
user_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Create a new pending run and register it.""" """Create a new pending run and register it."""
run_id = str(uuid.uuid4()) run_id = str(uuid.uuid4())
@@ -333,6 +339,7 @@ class RunManager:
multitask_strategy=multitask_strategy, multitask_strategy=multitask_strategy,
metadata=metadata or {}, metadata=metadata or {},
kwargs=kwargs or {}, kwargs=kwargs or {},
user_id=user_id,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
) )
@@ -504,6 +511,7 @@ class RunManager:
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
model_name: str | None = None, model_name: str | None = None,
user_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Atomically check for inflight runs and create a new one. """Atomically check for inflight runs and create a new one.
@@ -546,6 +554,7 @@ class RunManager:
multitask_strategy=multitask_strategy, multitask_strategy=multitask_strategy,
metadata=metadata or {}, metadata=metadata or {},
kwargs=kwargs or {}, kwargs=kwargs or {},
user_id=user_id,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
model_name=model_name, model_name=model_name,
+1
View File
@@ -36,6 +36,7 @@ dependencies = [
"sqlalchemy[asyncio]>=2.0,<3.0", "sqlalchemy[asyncio]>=2.0,<3.0",
"aiosqlite>=0.19", "aiosqlite>=0.19",
"alembic>=1.13", "alembic>=1.13",
"cryptography>=43.0.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -0,0 +1,106 @@
"""Regression anchors: channel runtime-config handlers must not block the event loop.
``configure_channel_provider_runtime`` and ``disconnect_channel_provider_runtime``
persist UI-entered channel credentials through ``ChannelRuntimeConfigStore``,
whose construction reads its JSON file and whose setters rewrite it
(``json.dump`` + ``Path.replace`` + ``chmod``). The handlers offload both via
``asyncio.to_thread``; if that regresses back onto the event loop, the strict
Blockbuster gate raises ``BlockingError`` and these tests fail.
The handlers are invoked directly with a minimal Starlette ``Request`` so the
surface under test is exactly the router's own IO, mirroring
``test_agents_router``. Test-side seeding/inspection is offloaded with
``asyncio.to_thread``.
"""
from __future__ import annotations
import asyncio
import importlib
from types import SimpleNamespace
from uuid import UUID
import pytest
from fastapi import FastAPI, Request
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.gateway.routers.channel_connections import (
ChannelRuntimeConfigRequest,
configure_channel_provider_runtime,
disconnect_channel_provider_runtime,
)
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
# Pre-import: the handlers import this module lazily; the import's file IO
# must happen at collection time, not on the event loop under the gate.
importlib.import_module("app.channels.service")
pytestmark = pytest.mark.asyncio
@pytest.fixture(autouse=True)
def _stub_app_config():
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def _make_request(tmp_path) -> Request:
app = FastAPI()
app.state.channel_connections_config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
app.state.channels_config = {}
app.state.channel_connection_repo = _FakeRepo()
store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app.state.channel_runtime_config_store = store
user = SimpleNamespace(id=UUID("11111111-2222-3333-4444-555555555555"), system_role="admin")
return Request({"type": "http", "app": app, "headers": [], "state": {"user": user}})
class _FakeRepo:
async def list_connections(self, owner_user_id):
return []
async def test_configure_runtime_channel_does_not_block_event_loop(tmp_path) -> None:
request = await asyncio.to_thread(_make_request, tmp_path)
response = await configure_channel_provider_runtime(
"slack",
ChannelRuntimeConfigRequest(values={"bot_token": "xoxb-ui", "app_token": "xapp-ui"}),
request,
)
assert response.provider == "slack"
store = request.app.state.channel_runtime_config_store
assert await asyncio.to_thread(store.get_provider_config, "slack") == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
async def test_disconnect_runtime_channel_does_not_block_event_loop(tmp_path) -> None:
request = await asyncio.to_thread(_make_request, tmp_path)
store = request.app.state.channel_runtime_config_store
await asyncio.to_thread(
store.set_provider_config,
"slack",
{"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"},
)
request.app.state.channels_config = {
"slack": {"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"},
}
response = await disconnect_channel_provider_runtime("slack", request)
assert response.provider == "slack"
assert await asyncio.to_thread(store.get_provider_config, "slack") == {
"enabled": False,
"_runtime_disabled": True,
}
@@ -0,0 +1,251 @@
"""Connection binding tests for browser-connectable IM channels beyond Telegram/Slack/Discord."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from app.channels.message_bus import InboundMessage, MessageBus
async def _make_repo(tmp_path, name: str):
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / f'{name}.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(get_session_factory())
async def _seed_state(repo, provider: str, state: str, owner_user_id: str = "deerflow-user-1") -> None:
await repo.create_oauth_state(
owner_user_id=owner_user_id,
provider=provider,
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
def test_feishu_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.feishu import FeishuChannel
async def go():
repo = await _make_repo(tmp_path, "feishu")
state = "feishu-bind-code"
await _seed_state(repo, "feishu", state)
channel = FeishuChannel(
bus=MessageBus(),
config={"app_id": "app", "app_secret": "secret", "connection_repo": repo},
)
channel._reply_card = AsyncMock()
handled = await channel._bind_connection_from_connect_code(
message_id="om-message-1",
chat_id="oc-chat-1",
user_id="ou-user-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "feishu"
assert connections[0]["external_account_id"] == "ou-user-1"
assert connections[0]["workspace_id"] == "oc-chat-1"
channel._reply_card.assert_awaited_once_with("om-message-1", "Feishu connected to DeerFlow.")
await repo.close()
anyio.run(go)
def test_dingtalk_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
async def go():
repo = await _make_repo(tmp_path, "dingtalk")
state = "dingtalk-bind-code"
await _seed_state(repo, "dingtalk", state)
channel = DingTalkChannel(
bus=MessageBus(),
config={"client_id": "client", "client_secret": "secret", "connection_repo": repo},
)
channel._send_connection_reply = AsyncMock()
handled = await channel._bind_connection_from_connect_code(
conversation_type=_CONVERSATION_TYPE_GROUP,
sender_staff_id="staff-user-1",
sender_nick="Alice",
conversation_id="cid-group-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "dingtalk"
assert connections[0]["external_account_id"] == "staff-user-1"
assert connections[0]["external_account_name"] == "Alice"
assert connections[0]["workspace_id"] == "cid-group-1"
channel._send_connection_reply.assert_awaited_once()
await repo.close()
anyio.run(go)
def test_wechat_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.wechat import WechatChannel
async def go():
repo = await _make_repo(tmp_path, "wechat")
state = "wechat-bind-code"
await _seed_state(repo, "wechat", state)
channel = WechatChannel(
bus=MessageBus(),
config={"bot_token": "token", "connection_repo": repo},
)
channel._send_connection_reply = AsyncMock()
handled = await channel._bind_connection_from_connect_code(
chat_id="wx-user-1",
context_token="ctx-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "wechat"
assert connections[0]["external_account_id"] == "wx-user-1"
assert connections[0]["workspace_id"] == "wx-user-1"
channel._send_connection_reply.assert_awaited_once_with("wx-user-1", "ctx-1", "WeChat connected to DeerFlow.")
await repo.close()
anyio.run(go)
def test_wecom_connect_command_binds_identity(tmp_path):
import anyio
from app.channels.wecom import WeComChannel
async def go():
repo = await _make_repo(tmp_path, "wecom")
state = "wecom-bind-code"
await _seed_state(repo, "wecom", state)
channel = WeComChannel(
bus=MessageBus(),
config={"bot_id": "bot", "bot_secret": "secret", "connection_repo": repo},
)
channel._ws_client = MagicMock()
channel._ws_client.reply = AsyncMock()
frame = {"body": {"aibotid": "bot-1", "chattype": "single"}}
handled = await channel._bind_connection_from_connect_code(
frame=frame,
user_id="wecom-user-1",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "wecom"
assert connections[0]["external_account_id"] == "wecom-user-1"
assert connections[0]["workspace_id"] == "bot-1"
channel._ws_client.reply.assert_awaited_once_with(frame, {"msgtype": "text", "text": {"content": "WeCom connected to DeerFlow."}})
await repo.close()
anyio.run(go)
def test_additional_channels_attach_owner_identity(tmp_path):
import anyio
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
from app.channels.feishu import FeishuChannel
from app.channels.wechat import WechatChannel
from app.channels.wecom import WeComChannel
async def go():
repo = await _make_repo(tmp_path, "additional-identity")
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="feishu",
external_account_id="ou-user-1",
workspace_id="oc-chat-1",
)
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="dingtalk",
external_account_id="staff-user-1",
workspace_id="cid-group-1",
)
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="wechat",
external_account_id="wx-user-1",
workspace_id="wx-user-1",
)
await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="wecom",
external_account_id="wecom-user-1",
workspace_id="bot-1",
)
cases = [
(
FeishuChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(channel_name="feishu", chat_id="oc-chat-1", user_id="ou-user-1", text="hello"),
),
(
DingTalkChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(
channel_name="dingtalk",
chat_id="cid-group-1",
user_id="staff-user-1",
text="hello",
metadata={
"conversation_type": _CONVERSATION_TYPE_GROUP,
"conversation_id": "cid-group-1",
},
),
),
(
WechatChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(channel_name="wechat", chat_id="wx-user-1", user_id="wx-user-1", text="hello"),
),
(
WeComChannel(bus=MessageBus(), config={"connection_repo": repo}),
InboundMessage(
channel_name="wecom",
chat_id="wecom-user-1",
user_id="wecom-user-1",
text="hello",
metadata={"aibotid": "bot-1"},
),
),
]
for channel, inbound in cases:
attached = await channel._attach_connection_identity(inbound)
assert attached.owner_user_id == "deerflow-user-1"
assert attached.connection_id
assert (
attached.workspace_id
== {
"feishu": "oc-chat-1",
"dingtalk": "cid-group-1",
"wechat": "wx-user-1",
"wecom": "bot-1",
}[channel.name]
)
await repo.close()
anyio.run(go)
+68
View File
@@ -280,6 +280,74 @@ def test_require_permission_denies_wrong_permission():
assert "Permission denied" in response.json()["detail"] assert "Permission denied" in response.json()["detail"]
def _make_internal_owner_check_app():
"""App with an owner_check route and a thread owned by ``alice``."""
import asyncio
from fastapi import Request
from langgraph.store.memory import InMemoryStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
app = FastAPI()
thread_store = MemoryThreadMetaStore(InMemoryStore())
asyncio.run(thread_store.create("alice-thread", user_id="alice"))
app.state.thread_store = thread_store
@app.get("/threads/{thread_id}")
@require_permission("threads", "read", owner_check=True)
async def endpoint(thread_id: str, request: Request):
return {"ok": True}
return app
def _internal_auth_context() -> AuthContext:
from types import SimpleNamespace
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
user = SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)
return AuthContext(user=user, permissions=[Permissions.THREADS_READ])
def test_require_permission_internal_role_scoped_by_owner_header():
"""An internal caller acting for the thread owner passes the owner check."""
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
app = _make_internal_owner_check_app()
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
with TestClient(app) as client:
response = client.get(
"/threads/alice-thread",
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "alice"},
)
assert response.status_code == 200
def test_require_permission_internal_role_denied_for_other_owner():
"""The internal token must not grant access to another user's thread."""
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
app = _make_internal_owner_check_app()
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
with TestClient(app) as client:
response = client.get(
"/threads/alice-thread",
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "mallory"},
)
assert response.status_code == 404
def test_require_permission_internal_role_without_header_is_scoped_to_internal_user():
"""With no owner header, internal callers are scoped like before the bypass."""
app = _make_internal_owner_check_app()
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
with TestClient(app) as client:
response = client.get("/threads/alice-thread")
assert response.status_code == 404
# ── Weak JWT secret warning ────────────────────────────────────────────────── # ── Weak JWT secret warning ──────────────────────────────────────────────────
+12 -9
View File
@@ -39,6 +39,8 @@ def test_public_paths(path: str):
"/api/threads/123/uploads", "/api/threads/123/uploads",
"/api/agents", "/api/agents",
"/api/channels", "/api/channels",
"/api/channels/providers",
"/api/channels/slack/connect",
"/api/runs/stream", "/api/runs/stream",
"/api/threads/123/runs", "/api/threads/123/runs",
"/api/v1/auth/me", "/api/v1/auth/me",
@@ -183,7 +185,7 @@ def _make_auth_csrf_app():
@pytest.fixture @pytest.fixture
def client(monkeypatch): def client(monkeypatch):
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "")
return TestClient(_make_app()) return TestClient(_make_app())
@@ -221,7 +223,7 @@ def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
assert res.json() == {"models": []} assert res.json() == {"models": []}
def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch): def test_auth_disabled_stamps_default_admin_user_without_cookie(monkeypatch):
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1") monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
client = TestClient(_make_app()) client = TestClient(_make_app())
@@ -229,10 +231,10 @@ def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch):
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == { assert res.json() == {
"id": "e2e-user", "id": "default",
"email": "e2e@test.local", "email": "default@test.local",
"system_role": "admin", "system_role": "admin",
"context_user_id": "e2e-user", "context_user_id": "default",
} }
@@ -244,8 +246,8 @@ def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == { assert res.json() == {
"id": "e2e-user", "id": "default",
"email": "e2e@test.local", "email": "default@test.local",
"system_role": "admin", "system_role": "admin",
"needs_setup": False, "needs_setup": False,
} }
@@ -329,7 +331,7 @@ def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
warn_if_auth_disabled_enabled() warn_if_auth_disabled_enabled()
assert "authentication is bypassed" in caplog.text assert "authentication is bypassed" in caplog.text
assert "e2e-user" in caplog.text assert "default" in caplog.text
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog): def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
@@ -348,7 +350,8 @@ def test_protected_path_with_junk_cookie_rejected(client):
"""Junk cookie → 401. Middleware strictly validates the JWT now """Junk cookie → 401. Middleware strictly validates the JWT now
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad (AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
tokens through to the route handler.""" tokens through to the route handler."""
res = client.get("/api/models", cookies={"access_token": "some-token"}) client.cookies.set("access_token", "some-token")
res = client.get("/api/models")
assert res.status_code == 401 assert res.status_code == 401
@@ -0,0 +1,56 @@
"""Tests for user-facing IM channel connection configuration."""
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
def test_channel_connections_disabled_by_default():
config = ChannelConnectionsConfig()
assert config.enabled is False
assert config.slack.enabled is False
assert config.telegram.enabled is False
assert config.discord.enabled is False
assert config.feishu.enabled is False
assert config.dingtalk.enabled is False
assert config.wechat.enabled is False
assert config.wecom.enabled is False
def test_enabled_channel_connections_do_not_require_public_url_or_encryption_key():
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {
"enabled": True,
"bot_username": "deerflow_bot",
},
"slack": {"enabled": True},
"discord": {"enabled": True},
"feishu": {"enabled": True},
"dingtalk": {"enabled": True},
"wechat": {"enabled": True},
"wecom": {"enabled": True},
}
)
assert config.enabled is True
assert config.provider_status("telegram") == {"enabled": True, "configured": True}
assert config.provider_status("slack") == {"enabled": True, "configured": True}
assert config.provider_status("discord") == {"enabled": True, "configured": True}
assert config.provider_status("feishu") == {"enabled": True, "configured": True}
assert config.provider_status("dingtalk") == {"enabled": True, "configured": True}
assert config.provider_status("wechat") == {"enabled": True, "configured": True}
assert config.provider_status("wecom") == {"enabled": True, "configured": True}
def test_provider_status_reports_disabled_and_unknown_providers():
config = ChannelConnectionsConfig.model_validate({"enabled": True})
assert config.provider_status("slack") == {"enabled": False, "configured": False}
assert config.provider_status("telegram") == {"enabled": False, "configured": False}
assert config.provider_status("discord") == {"enabled": False, "configured": False}
assert config.provider_status("feishu") == {"enabled": False, "configured": False}
assert config.provider_status("dingtalk") == {"enabled": False, "configured": False}
assert config.provider_status("wechat") == {"enabled": False, "configured": False}
assert config.provider_status("wecom") == {"enabled": False, "configured": False}
assert config.provider_status("unknown") == {"enabled": False, "configured": False}
@@ -0,0 +1,331 @@
"""Tests for per-user IM channel connection persistence."""
from __future__ import annotations
import logging
from datetime import UTC, datetime, timedelta
import pytest
from sqlalchemy import select
from deerflow.persistence.channel_connections import (
ChannelConnectionRepository,
ChannelConnectionRow,
ChannelCredentialCipher,
ChannelCredentialRow,
ChannelOAuthStateRow,
)
@pytest.fixture
async def repo(tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
try:
yield ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("test-encryption-key"),
)
finally:
await close_engine()
class TestChannelConnectionRepository:
@pytest.mark.anyio
async def test_connections_are_listed_per_owner(self, repo):
alice = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
external_account_name="Alice",
workspace_id="T1",
workspace_name="Team One",
scopes=["chat:write"],
)
await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-bob",
external_account_name="Bob",
workspace_id="T1",
workspace_name="Team One",
scopes=["chat:write"],
)
results = await repo.list_connections("alice")
assert [item["id"] for item in results] == [alice["id"]]
assert results[0]["owner_user_id"] == "alice"
assert results[0]["provider"] == "slack"
assert results[0]["scopes"] == ["chat:write"]
assert "encrypted_access_token" not in results[0]
@pytest.mark.anyio
async def test_upsert_connection_updates_existing_provider_identity(self, repo):
first = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
external_account_name="Alice",
workspace_id=None,
workspace_name=None,
status="pending",
)
second = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
external_account_name="Alice Telegram",
workspace_id=None,
workspace_name=None,
status="connected",
)
assert second["id"] == first["id"]
assert second["status"] == "connected"
assert second["external_account_name"] == "Alice Telegram"
assert len(await repo.list_connections("alice")) == 1
@pytest.mark.anyio
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
expires_at = datetime.now(UTC) + timedelta(hours=1)
await repo.store_credentials(
connection["id"],
access_token="xoxb-secret-access-token",
refresh_token="secret-refresh-token",
token_type="Bearer",
expires_at=expires_at,
extra={"bot_user_id": "B123"},
)
async with repo.session_factory() as session:
row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
assert row.encrypted_access_token is not None
assert "xoxb-secret-access-token" not in row.encrypted_access_token
assert "secret-refresh-token" not in (row.encrypted_refresh_token or "")
assert "B123" not in (row.encrypted_extra_json or "")
credentials = await repo.get_credentials(connection["id"])
assert credentials is not None
assert credentials["access_token"] == "xoxb-secret-access-token"
assert credentials["refresh_token"] == "secret-refresh-token"
assert credentials["token_type"] == "Bearer"
assert credentials["expires_at"] == expires_at
assert credentials["extra"] == {"bot_user_id": "B123"}
@pytest.mark.anyio
async def test_get_credentials_returns_none_when_decryption_fails(self, repo, caplog):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
await repo.store_credentials(connection["id"], access_token="xoxb-secret-access-token")
wrong_key_repo = ChannelConnectionRepository(
repo.session_factory,
cipher=ChannelCredentialCipher.from_key("wrong-encryption-key"),
)
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.channel_connections.sql"):
credentials = await wrong_key_repo.get_credentials(connection["id"])
assert credentials is None
assert any("Unable to decrypt channel connection credentials" in record.message for record in caplog.records)
@pytest.mark.anyio
async def test_conversations_are_scoped_by_connection(self, repo):
alice = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
bob = await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-bob",
workspace_id="T1",
)
await repo.set_thread_id(
connection_id=alice["id"],
owner_user_id="alice",
provider="slack",
external_conversation_id="C-shared",
external_topic_id="1710000000.000100",
thread_id="thread-alice",
)
await repo.set_thread_id(
connection_id=bob["id"],
owner_user_id="bob",
provider="slack",
external_conversation_id="C-shared",
external_topic_id="1710000000.000100",
thread_id="thread-bob",
)
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
@pytest.mark.anyio
async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
)
await repo.store_credentials(connection["id"], access_token="secret-token")
disconnected = await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id="alice",
)
assert disconnected is True
async with repo.session_factory() as session:
connection_row = await session.get(ChannelConnectionRow, connection["id"])
credential_row = await session.get(ChannelCredentialRow, connection["id"])
assert connection_row is not None
assert connection_row.status == "revoked"
assert credential_row is None
assert (
await repo.find_connection_by_external_identity(
provider="telegram",
external_account_id="42",
)
is None
)
@pytest.mark.anyio
async def test_disconnect_connection_is_owner_scoped(self, repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="telegram",
external_account_id="42",
)
disconnected = await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id="bob",
)
assert disconnected is False
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
@pytest.mark.anyio
async def test_consume_oauth_state_deletes_expired_states(self, repo):
now = datetime.now(UTC)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="expired-state",
expires_at=now - timedelta(minutes=1),
)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="active-state",
expires_at=now + timedelta(minutes=5),
)
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
assert consumed is None
async with repo.session_factory() as session:
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
@pytest.mark.anyio
async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo):
import anyio
now = datetime.now(UTC)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="bind-once",
expires_at=now + timedelta(minutes=5),
)
results: list = []
async def consume():
results.append(await repo.consume_oauth_state(provider="slack", state="bind-once", now=now))
async with anyio.create_task_group() as tg:
tg.start_soon(consume)
tg.start_soon(consume)
consumed = [result for result in results if result is not None]
assert len(consumed) == 1
assert consumed[0]["owner_user_id"] == "alice"
@pytest.mark.anyio
async def test_upsert_connection_retries_as_update_when_concurrent_insert_wins(self, repo):
"""A losing concurrent INSERT retries as an UPDATE instead of raising IntegrityError."""
first = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-race",
workspace_id="T-race",
status="pending",
)
real_factory = repo.session_factory
class _EmptyResult:
@staticmethod
def scalar_one_or_none():
return None
class MissFirstSelectSession:
"""Make the initial identity SELECT miss, as if a concurrent writer inserted after it."""
def __init__(self, session):
self._session = session
self._missed = False
def __getattr__(self, name):
return getattr(self._session, name)
async def execute(self, *args, **kwargs):
result = await self._session.execute(*args, **kwargs)
if not self._missed:
self._missed = True
return _EmptyResult()
return result
async def __aenter__(self):
await self._session.__aenter__()
return self
async def __aexit__(self, *args):
return await self._session.__aexit__(*args)
repo.session_factory = lambda: MissFirstSelectSession(real_factory())
try:
second = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-race",
workspace_id="T-race",
status="connected",
)
finally:
repo.session_factory = real_factory
assert second["id"] == first["id"]
assert second["status"] == "connected"
connections = await repo.list_connections("alice")
assert len(connections) == 1
@@ -0,0 +1,963 @@
"""Router tests for browser-connectable IM channels."""
from __future__ import annotations
from tempfile import TemporaryDirectory
from types import SimpleNamespace
from unittest.mock import AsyncMock
from uuid import UUID
import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.gateway.auth.models import User
from app.gateway.routers import channel_connections
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
@pytest.fixture(autouse=True)
def _stub_app_config(monkeypatch):
"""Keep router tests independent from a developer-local config.yaml."""
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "0")
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def _user() -> User:
return User(
id=UUID("11111111-2222-3333-4444-555555555555"),
email="alice@example.com",
password_hash="x",
system_role="admin",
)
def _non_admin_user() -> User:
return User(
id=UUID("99999999-8888-7777-6666-555555555555"),
email="bob@example.com",
password_hash="x",
system_role="user",
)
async def _make_repo(tmp_path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'router.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(get_session_factory())
def _make_app(
config: ChannelConnectionsConfig,
repo,
channels_config: dict | None = None,
*,
runtime_config_store: ChannelRuntimeConfigStore | None = None,
set_channels_config_state: bool = True,
):
app = make_authed_test_app(user_factory=_user)
app.state.channel_connections_config = config
app.state.channel_connection_repo = repo
if set_channels_config_state:
app.state.channels_config = channels_config or {}
if runtime_config_store is None:
runtime_config_dir = TemporaryDirectory()
app.state.channel_runtime_config_tmpdir = runtime_config_dir
runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json")
app.state.channel_runtime_config_store = runtime_config_store
app.include_router(channel_connections.router)
return app
def _enabled_connections_config() -> ChannelConnectionsConfig:
return ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
"slack": {"enabled": True},
"discord": {"enabled": True},
"feishu": {"enabled": True},
"dingtalk": {"enabled": True},
"wechat": {"enabled": True},
"wecom": {"enabled": True},
}
)
def _channels_config() -> dict:
return {
"telegram": {"enabled": True, "bot_token": "telegram-token"},
"slack": {"enabled": True, "bot_token": "xoxb-operator", "app_token": "xapp-operator"},
"discord": {"enabled": True, "bot_token": "discord-bot"},
"feishu": {"enabled": True, "app_id": "feishu-app", "app_secret": "feishu-secret"},
"dingtalk": {"enabled": True, "client_id": "dingtalk-client", "client_secret": "dingtalk-secret"},
"wechat": {"enabled": True, "bot_token": "wechat-token"},
"wecom": {"enabled": True, "bot_id": "wecom-bot", "bot_secret": "wecom-secret"},
}
def test_get_providers_only_returns_enabled_channels_and_setup_fields(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
"discord": {"enabled": False},
}
)
app = _make_app(config, repo, {})
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
body = response.json()
assert body["enabled"] is True
assert [provider["provider"] for provider in body["providers"]] == ["slack"]
assert body["providers"][0]["configured"] is False
assert body["providers"][0]["connectable"] is False
assert body["providers"][0]["credential_fields"] == [
{
"name": "bot_token",
"label": "Bot token",
"type": "password",
"required": True,
},
{
"name": "app_token",
"label": "App token",
"type": "password",
"required": True,
},
]
anyio.run(repo.close)
def test_get_providers_uses_existing_channels_config(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
body = response.json()
assert body["enabled"] is True
by_provider = {item["provider"]: item for item in body["providers"]}
assert set(by_provider) == {"telegram", "slack", "discord", "feishu", "dingtalk", "wechat", "wecom"}
assert by_provider["telegram"]["configured"] is True
assert by_provider["telegram"]["auth_mode"] == "deep_link"
assert by_provider["telegram"]["credential_values"] == {
"bot_token": "********",
"bot_username": "deerflow_bot",
}
assert by_provider["slack"]["configured"] is True
assert by_provider["slack"]["auth_mode"] == "binding_code"
assert by_provider["slack"]["connection_status"] == "not_connected"
assert by_provider["slack"]["credential_values"] == {
"bot_token": "********",
"app_token": "********",
}
assert by_provider["discord"]["configured"] is True
assert by_provider["discord"]["auth_mode"] == "binding_code"
assert by_provider["discord"]["credential_values"] == {"bot_token": "********"}
assert by_provider["feishu"]["configured"] is True
assert by_provider["feishu"]["auth_mode"] == "binding_code"
assert by_provider["feishu"]["connection_status"] == "not_connected"
assert by_provider["feishu"]["credential_values"] == {
"app_id": "feishu-app",
"app_secret": "********",
}
assert by_provider["dingtalk"]["configured"] is True
assert by_provider["dingtalk"]["auth_mode"] == "binding_code"
assert by_provider["dingtalk"]["credential_values"] == {
"client_id": "dingtalk-client",
"client_secret": "********",
}
assert by_provider["wechat"]["configured"] is True
assert by_provider["wechat"]["auth_mode"] == "binding_code"
assert by_provider["wechat"]["credential_values"] == {"bot_token": "********"}
assert by_provider["wecom"]["configured"] is True
assert by_provider["wecom"]["auth_mode"] == "binding_code"
assert by_provider["wecom"]["credential_values"] == {
"bot_id": "wecom-bot",
"bot_secret": "********",
}
anyio.run(repo.close)
def test_get_providers_degrades_when_persistence_is_unavailable(monkeypatch):
monkeypatch.setattr(channel_connections, "get_session_factory", lambda: None)
app = _make_app(_enabled_connections_config(), None, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["slack"]["configured"] is True
assert by_provider["slack"]["connectable"] is True
assert by_provider["slack"]["connection_status"] == "not_connected"
def test_get_providers_reports_connected_without_binding_in_auth_disabled_mode(tmp_path, monkeypatch):
import anyio
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
monkeypatch.delenv("ENVIRONMENT", raising=False)
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
# Auth-disabled local mode routes channel messages to the default user, so
# a configured running channel is effectively connected without a binding.
assert by_provider["slack"]["connection_status"] == "connected"
assert by_provider["feishu"]["connection_status"] == "connected"
anyio.run(repo.close)
def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, {"telegram": {"enabled": True, "bot_token": "telegram-token"}})
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["telegram"]["configured"] is True
assert by_provider["slack"]["configured"] is False
assert by_provider["slack"]["connectable"] is False
assert "Slack credentials" in by_provider["slack"]["unavailable_reason"]
assert by_provider["discord"]["configured"] is False
assert "Discord credentials" in by_provider["discord"]["unavailable_reason"]
assert by_provider["feishu"]["configured"] is False
assert "Feishu credentials" in by_provider["feishu"]["unavailable_reason"]
assert by_provider["dingtalk"]["configured"] is False
assert "DingTalk credentials" in by_provider["dingtalk"]["unavailable_reason"]
assert by_provider["wechat"]["configured"] is False
assert "WeChat credentials" in by_provider["wechat"]["unavailable_reason"]
assert by_provider["wecom"]["configured"] is False
assert "WeCom credentials" in by_provider["wecom"]["unavailable_reason"]
anyio.run(repo.close)
def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypatch):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
service = SimpleNamespace(
get_status=lambda: {
"service_running": True,
"channels": {
"feishu": {
"enabled": True,
"running": False,
}
},
}
)
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["feishu"]["configured"] is True
assert by_provider["feishu"]["connectable"] is False
assert by_provider["feishu"]["connection_status"] == "not_connected"
assert "configured but is not running" in by_provider["feishu"]["unavailable_reason"]
anyio.run(repo.close)
def test_get_providers_restarts_configured_channel_when_service_can_reconcile(tmp_path, monkeypatch):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
)
channels_config = {
"feishu": {
"enabled": True,
"app_id": "feishu-app",
"app_secret": "feishu-secret",
}
}
app = _make_app(config, repo, channels_config)
status = {
"service_running": True,
"channels": {
"feishu": {
"enabled": True,
"running": False,
}
},
}
reconciled: list[tuple[str, dict]] = []
async def ensure_channel_ready(provider, runtime_config):
reconciled.append((provider, dict(runtime_config)))
status["channels"][provider]["running"] = True
return True
service = SimpleNamespace(
get_status=lambda: status,
ensure_channel_ready=ensure_channel_ready,
)
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["feishu"]["configured"] is True
assert by_provider["feishu"]["connectable"] is True
assert by_provider["feishu"]["connection_status"] == "not_connected"
assert by_provider["feishu"]["unavailable_reason"] is None
assert reconciled == [("feishu", channels_config["feishu"])]
anyio.run(repo.close)
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connections():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U-old",
workspace_id="T-old",
status="revoked",
)
await anyio.sleep(0.01)
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U-new",
workspace_id="T-new",
status="connected",
)
anyio.run(seed_connections)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["slack"]["connection_status"] == "connected"
anyio.run(repo.close)
def test_get_connections_returns_current_user_connections_only(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connections():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="telegram",
external_account_id="42",
external_account_name="Alice",
status="connected",
)
await repo.upsert_connection(
owner_user_id="other-user",
provider="telegram",
external_account_id="99",
external_account_name="Bob",
status="connected",
)
anyio.run(seed_connections)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.get("/api/channels/connections")
assert response.status_code == 200
body = response.json()
assert len(body["connections"]) == 1
assert body["connections"][0]["provider"] == "telegram"
assert body["connections"][0]["external_account_id"] == "42"
anyio.run(repo.close)
def test_connect_telegram_returns_deep_link_and_persists_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.post("/api/channels/telegram/connect")
assert response.status_code == 200
body = response.json()
assert body["provider"] == "telegram"
assert body["mode"] == "deep_link"
assert body["url"].startswith("https://t.me/deerflow_bot?start=")
assert body["code"]
assert "/start" in body["instruction"]
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="telegram")
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.post("/api/channels/slack/connect")
assert response.status_code == 200
body = response.json()
assert body["provider"] == "slack"
assert body["mode"] == "binding_code"
assert body["url"] is None
assert len(body["code"]) >= 22
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.post("/api/channels/discord/connect")
assert response.status_code == 200
body = response.json()
assert body["provider"] == "discord"
assert body["mode"] == "binding_code"
assert body["url"] is None
assert body["code"]
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Discord bot."
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="discord")
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_existing_binding_code_channels_return_command_and_persist_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
providers = ["feishu", "dingtalk", "wechat", "wecom"]
with TestClient(app) as client:
responses = {provider: client.post(f"/api/channels/{provider}/connect") for provider in providers}
for provider, response in responses.items():
expected_display_name = {
"feishu": "Feishu",
"dingtalk": "DingTalk",
"wechat": "WeChat",
"wecom": "WeCom",
}[provider]
assert response.status_code == 200
body = response.json()
assert body["provider"] == provider
assert body["mode"] == "binding_code"
assert body["url"] is None
assert len(body["code"]) >= 22
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow {expected_display_name} bot."
async def count_states(provider=provider):
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider=provider)
assert anyio.run(count_states) == 1
anyio.run(repo.close)
def test_connect_unconfigured_runtime_channel_returns_400(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, {})
with TestClient(app) as client:
response = client.post("/api/channels/slack/connect")
assert response.status_code == 400
assert "Slack credentials" in response.json()["detail"]
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_enables_connect_without_file_edits(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
app = _make_app(config, repo, {})
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
connect_response = client.post("/api/channels/slack/connect")
assert configure_response.status_code == 200
configured = configure_response.json()
assert configured["provider"] == "slack"
assert configured["configured"] is True
assert configured["connectable"] is True
assert configured["connection_status"] == "not_connected"
assert app.state.channels_config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
assert connect_response.status_code == 200
assert connect_response.json()["provider"] == "slack"
anyio.run(repo.close)
def test_runtime_config_endpoints_require_admin(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
app = make_authed_test_app(user_factory=_non_admin_user)
app.state.channel_connections_config = config
app.state.channel_connection_repo = repo
app.state.channels_config = {}
runtime_config_dir = TemporaryDirectory()
app.state.channel_runtime_config_tmpdir = runtime_config_dir
app.state.channel_runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json")
app.include_router(channel_connections.router)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
providers_response = client.get("/api/channels/providers")
assert configure_response.status_code == 403
assert "Admin privileges" in configure_response.json()["detail"]
assert disconnect_response.status_code == 403
# Read-only provider listing stays available to regular users.
assert providers_response.status_code == 200
anyio.run(repo.close)
def test_configure_telegram_runtime_uses_new_bot_username_for_deep_link_without_mutating_config(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "old_bot"},
}
)
app = _make_app(config, repo, {})
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/telegram/runtime-config",
json={"values": {"bot_token": "tg-token", "bot_username": "new_bot"}},
)
connect_response = client.post("/api/channels/telegram/connect")
assert configure_response.status_code == 200
assert configure_response.json()["credential_values"]["bot_username"] == "new_bot"
assert connect_response.status_code == 200
assert connect_response.json()["url"].startswith("https://t.me/new_bot?start=")
# The original config object cached by get_app_config() must stay untouched.
assert config.telegram.bot_username == "old_bot"
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_path = tmp_path / "channels" / "runtime-config.json"
first_app = _make_app(
config,
repo,
{},
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
)
with TestClient(first_app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
assert configure_response.status_code == 200
restarted_app = _make_app(
config,
repo,
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
set_channels_config_state=False,
)
with TestClient(restarted_app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["slack"]["configured"] is True
assert by_provider["slack"]["connectable"] is True
assert by_provider["slack"]["connection_status"] == "not_connected"
assert restarted_app.state.channels_config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_preserves_masked_secrets(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(
config,
repo,
{
"feishu": {
"enabled": True,
"app_id": "old-app-id",
"app_secret": "old-secret",
}
},
runtime_config_store=runtime_config_store,
)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/feishu/runtime-config",
json={
"values": {
"app_id": "new-app-id",
"app_secret": "********",
}
},
)
providers_response = client.get("/api/channels/providers")
assert configure_response.status_code == 200
assert app.state.channels_config["feishu"] == {
"enabled": True,
"app_id": "new-app-id",
"app_secret": "old-secret",
}
assert runtime_config_store.get_provider_config("feishu") == {
"enabled": True,
"app_id": "new-app-id",
"app_secret": "old-secret",
}
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["feishu"]["credential_values"] == {
"app_id": "new-app-id",
"app_secret": "********",
}
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
providers_response = client.get("/api/channels/providers")
assert configure_response.status_code == 200
assert disconnect_response.status_code == 200
disconnected = disconnect_response.json()
assert disconnected["provider"] == "slack"
assert disconnected["configured"] is False
assert disconnected["connectable"] is False
assert disconnected["connection_status"] == "not_connected"
assert runtime_config_store.get_provider_config("slack") == {
"enabled": False,
"_runtime_disabled": True,
}
assert providers_response.status_code == 200
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["slack"]["connection_status"] == "not_connected"
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_suppresses_file_config_and_stops_channel(tmp_path, monkeypatch):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
)
set_app_config(
AppConfig.model_validate(
{
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
"channels": {
"feishu": {
"enabled": True,
"app_id": "file-app-id",
"app_secret": "file-secret",
}
},
}
)
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
runtime_config_store.set_provider_config(
"feishu",
{
"enabled": True,
"app_id": "runtime-app-id",
"app_secret": "runtime-secret",
},
)
service = SimpleNamespace(
configure_channel=AsyncMock(return_value=True),
remove_channel=AsyncMock(return_value=True),
)
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
app = _make_app(
config,
repo,
{
"feishu": {
"enabled": True,
"app_id": "runtime-app-id",
"app_secret": "runtime-secret",
}
},
runtime_config_store=runtime_config_store,
)
with TestClient(app) as client:
disconnect_response = client.delete("/api/channels/feishu/runtime-config")
providers_response = client.get("/api/channels/providers")
assert disconnect_response.status_code == 200
disconnected = disconnect_response.json()
assert disconnected["provider"] == "feishu"
assert disconnected["configured"] is False
assert disconnected["connectable"] is False
assert disconnected["connection_status"] == "not_connected"
assert "feishu" not in app.state.channels_config
service.remove_channel.assert_awaited_once_with("feishu")
service.configure_channel.assert_not_awaited()
assert providers_response.status_code == 200
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["feishu"]["configured"] is False
assert by_provider["feishu"]["connection_status"] == "not_connected"
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U123",
status="connected",
)
anyio.run(seed_connection)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
assert configure_response.status_code == 200
assert disconnect_response.status_code == 200
async def get_connection_status():
return (await repo.list_connections(str(_user().id)))[0]["status"]
assert anyio.run(get_connection_status) == "revoked"
anyio.run(repo.close)
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
connection = await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="telegram",
external_account_id="42",
status="connected",
)
return connection["id"]
connection_id = anyio.run(seed_connection)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.delete(f"/api/channels/connections/{connection_id}")
assert response.status_code == 204
async def get_connection_status():
return (await repo.list_connections(str(_user().id)))[0]["status"]
assert anyio.run(get_connection_status) == "revoked"
anyio.run(repo.close)
def test_disconnect_connection_is_current_user_scoped(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
connection = await repo.upsert_connection(
owner_user_id="other-user",
provider="telegram",
external_account_id="42",
status="connected",
)
return connection["id"]
connection_id = anyio.run(seed_connection)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
response = client.delete(f"/api/channels/connections/{connection_id}")
assert response.status_code == 404
async def get_connection_status():
return (await repo.list_connections("other-user"))[0]["status"]
assert anyio.run(get_connection_status) == "connected"
anyio.run(repo.close)
+565 -10
View File
@@ -487,6 +487,7 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
# threads.create() returns a Thread-like dict # threads.create() returns a Thread-like dict
mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id}) mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id})
mock_client.threads.update = AsyncMock(return_value={"thread_id": thread_id})
# threads.get() returns thread info (succeeds by default) # threads.get() returns thread info (succeeds by default)
mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id}) mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id})
@@ -504,6 +505,17 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
return mock_client return mock_client
async def _make_channel_connection_repo(tmp_path: Path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'channel-connections.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("test-channel-key"),
)
def _make_stream_part(event: str, data): def _make_stream_part(event: str, data):
return SimpleNamespace(event=event, data=data) return SimpleNamespace(event=event, data=data)
@@ -656,16 +668,34 @@ class TestChannelManager:
await manager.start() await manager.start()
inbound = InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi") inbound = InboundMessage(
channel_name="test",
chat_id="chat1",
user_id="user1",
text="hi",
topic_id="topic1",
thread_ts="msg1",
connection_id="conn1",
)
await bus.publish_inbound(inbound) await bus.publish_inbound(inbound)
await _wait_for(lambda: len(outbound_received) >= 1) await _wait_for(lambda: len(outbound_received) >= 1)
await manager.stop() await manager.stop()
# Thread should be created through Gateway # Thread should be created through Gateway
mock_client.threads.create.assert_called_once() mock_client.threads.create.assert_called_once()
assert mock_client.threads.create.call_args.kwargs["metadata"] == {
"channel_source": {
"type": "im_channel",
"provider": "test",
"chat_id": "chat1",
"topic_id": "topic1",
"thread_ts": "msg1",
"connection_id": "conn1",
}
}
# Thread ID should be stored # Thread ID should be stored
thread_id = store.get_thread_id("test", "chat1") thread_id = store.get_thread_id("test", "chat1", topic_id="topic1")
assert thread_id == "test-thread-123" assert thread_id == "test-thread-123"
# runs.wait should be called with the thread_id # runs.wait should be called with the thread_id
@@ -883,10 +913,12 @@ class TestChannelManager:
_run(go()) _run(go())
def test_clarification_follow_up_preserves_history(self): def test_clarification_follow_up_preserves_history(self, monkeypatch):
"""Conversation should continue after ask_clarification instead of resetting history.""" """Conversation should continue after ask_clarification instead of resetting history."""
from app.channels.manager import ChannelManager from app.channels.manager import ChannelManager
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go(): async def go():
bus = MessageBus() bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
@@ -1954,10 +1986,12 @@ class TestChannelManager:
_run(go()) _run(go())
def test_same_topic_reuses_thread(self): def test_same_topic_reuses_thread(self, monkeypatch):
"""Messages with the same topic_id should reuse the same DeerFlow thread.""" """Messages with the same topic_id should reuse the same DeerFlow thread."""
from app.channels.manager import ChannelManager from app.channels.manager import ChannelManager
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go(): async def go():
bus = MessageBus() bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
@@ -1990,6 +2024,17 @@ class TestChannelManager:
# threads.create should be called only ONCE (second message reuses the thread) # threads.create should be called only ONCE (second message reuses the thread)
mock_client.threads.create.assert_called_once() mock_client.threads.create.assert_called_once()
mock_client.threads.update.assert_called_once_with(
"topic-thread-1",
metadata={
"channel_source": {
"type": "im_channel",
"provider": "test",
"chat_id": "chat1",
"topic_id": "topic-root-123",
}
},
)
# Both runs.wait calls should use the same thread_id # Both runs.wait calls should use the same thread_id
assert mock_client.runs.wait.call_count == 2 assert mock_client.runs.wait.call_count == 2
@@ -2325,8 +2370,9 @@ class TestResolveRunParamsUserId:
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
return ChannelManager(bus=bus, store=store) return ChannelManager(bus=bus, store=store)
def test_safe_user_id_is_passed_through(self): def test_safe_user_id_is_passed_through(self, monkeypatch):
manager = self._manager() manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi") msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1") _, _, run_context = manager._resolve_run_params(msg, "thread-1")
@@ -2334,10 +2380,78 @@ class TestResolveRunParamsUserId:
assert run_context["user_id"] == "123456" assert run_context["user_id"] == "123456"
assert run_context["channel_user_id"] == "123456" assert run_context["channel_user_id"] == "123456"
def test_unsafe_user_id_is_normalized_but_raw_preserved(self): def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch):
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(
channel_name="slack",
chat_id="C123",
user_id="U-platform",
owner_user_id="deerflow-user-1",
connection_id="connection-1",
text="hi",
)
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == "deerflow-user-1"
assert run_context["channel_user_id"] == "U-platform"
def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch):
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
manager = self._manager()
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
assert run_context["channel_user_id"] == "U-platform"
from app.channels.manager import _owner_headers
headers = _owner_headers(msg)
assert headers is not None
assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID
def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch):
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
manager = self._manager()
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
msg = InboundMessage(
channel_name="slack",
chat_id="C123",
user_id="U-platform",
owner_user_id="real-user-from-old-binding",
text="hi",
)
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
assert run_context["channel_user_id"] == "U-platform"
def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch):
from app.channels.manager import _owner_headers
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == "U-platform"
assert run_context["channel_user_id"] == "U-platform"
assert _owner_headers(msg) is None
def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch):
from deerflow.config.paths import make_safe_user_id from deerflow.config.paths import make_safe_user_id
manager = self._manager() manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
raw = "user@example.com" raw = "user@example.com"
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi") msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
@@ -2347,9 +2461,32 @@ class TestResolveRunParamsUserId:
assert run_context["user_id"] != raw assert run_context["user_id"] != raw
assert run_context["channel_user_id"] == raw assert run_context["channel_user_id"] == raw
@pytest.mark.parametrize("raw_user_id", ["", None]) def test_unsafe_user_id_migrates_unique_legacy_bucket(self, tmp_path, monkeypatch):
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id): from deerflow.config.paths import Paths, make_safe_user_id
paths = Paths(tmp_path)
legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b"
legacy_dir.mkdir(parents=True)
(legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8")
monkeypatch.setattr("deerflow.config.paths.get_paths", lambda: paths)
manager = self._manager() manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
raw = "user@example.com"
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
safe = make_safe_user_id(raw)
assert run_context["user_id"] == safe
assert paths.user_dir(safe).exists()
assert not legacy_dir.exists()
assert (paths.user_dir(safe) / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n'
@pytest.mark.parametrize("raw_user_id", ["", None])
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch):
manager = self._manager()
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi") msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1") _, _, run_context = manager._resolve_run_params(msg, "thread-1")
@@ -2358,6 +2495,93 @@ class TestResolveRunParamsUserId:
assert "channel_user_id" not in run_context assert "channel_user_id" not in run_context
class TestChannelManagerConnectionRouting:
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch):
from app.channels.manager import ChannelManager
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
from deerflow.persistence.engine import close_engine
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go():
repo = await _make_channel_connection_repo(tmp_path)
alice = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-alice",
workspace_id="T1",
)
bob = await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-bob",
workspace_id="T1",
)
bus = MessageBus()
store = ChannelStore(path=tmp_path / "legacy-store.json")
manager = ChannelManager(bus=bus, store=store, connection_repo=repo)
mock_client = _make_mock_langgraph_client()
mock_client.threads.create = AsyncMock(
side_effect=[
{"thread_id": "thread-alice"},
{"thread_id": "thread-bob"},
]
)
manager._client = mock_client
await manager._handle_chat(
InboundMessage(
channel_name="slack",
chat_id="C-shared",
user_id="U-alice",
owner_user_id="alice",
connection_id=alice["id"],
text="hello",
thread_ts="1710000000.000100",
topic_id="1710000000.000100",
)
)
await manager._handle_chat(
InboundMessage(
channel_name="slack",
chat_id="C-shared",
user_id="U-bob",
owner_user_id="bob",
connection_id=bob["id"],
text="hello",
thread_ts="1710000000.000100",
topic_id="1710000000.000100",
)
)
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
assert store.list_entries() == []
first_context = mock_client.runs.wait.call_args_list[0].kwargs["context"]
second_context = mock_client.runs.wait.call_args_list[1].kwargs["context"]
assert first_context["user_id"] == "alice"
assert first_context["channel_user_id"] == "U-alice"
assert second_context["user_id"] == "bob"
assert second_context["channel_user_id"] == "U-bob"
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
try:
_run(go())
finally:
_run(close_engine())
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# ChannelService tests # ChannelService tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -3108,6 +3332,38 @@ class TestChannelService:
_run(go()) _run(go())
def test_concurrent_ensure_channel_ready_starts_channel_once(self):
from app.channels.service import ChannelService
async def go():
service = ChannelService(
channels_config={
"telegram": {"enabled": True, "bot_token": "tg-token"},
}
)
await service.manager.start()
service._running = True
start_calls = []
async def fake_start_channel(name, config):
start_calls.append(name)
await asyncio.sleep(0.01)
service._channels[name] = SimpleNamespace(is_running=True, stop=AsyncMock())
return True
service._start_channel = fake_start_channel
results = await asyncio.gather(
service.ensure_channel_ready("telegram"),
service.ensure_channel_ready("telegram"),
)
assert results == [True, True]
assert start_calls == ["telegram"]
await service.stop()
_run(go())
def test_session_config_is_forwarded_to_manager(self): def test_session_config_is_forwarded_to_manager(self):
from app.channels.service import ChannelService from app.channels.service import ChannelService
@@ -3175,6 +3431,226 @@ class TestChannelService:
assert service._config == {"telegram": {"enabled": False}} assert service._config == {"telegram": {"enabled": False}}
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(
self,
monkeypatch,
tmp_path,
):
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
app_config = SimpleNamespace(
model_extra={},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
"slack": {"enabled": True},
"discord": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config == {}
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(
self,
monkeypatch,
tmp_path,
):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"slack",
{
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
)
app_config = SimpleNamespace(
model_extra={
"channels": {
"telegram": {"enabled": True, "bot_token": "telegram-token"},
"slack": {"enabled": True, "bot_token": "xoxb", "app_token": "xapp"},
"discord": {"enabled": True, "bot_token": "discord-bot-token"},
}
},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
"slack": {"enabled": True},
"discord": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config["telegram"]["bot_token"] == "telegram-token"
assert service._config["slack"]["app_token"] == "xapp"
assert service._config["discord"]["bot_token"] == "discord-bot-token"
def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"slack",
{
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
)
app_config = SimpleNamespace(
model_extra={},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
def test_from_app_config_runtime_disconnect_suppresses_file_channel_config(self, monkeypatch, tmp_path):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"feishu",
{
"enabled": False,
"_runtime_disabled": True,
},
)
app_config = SimpleNamespace(
model_extra={
"channels": {
"feishu": {
"enabled": True,
"app_id": "file-app-id",
"app_secret": "file-secret",
}
}
},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"feishu": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert "feishu" not in service._config
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
from app.channels.service import ChannelService
class FlakyReadyChannel(Channel):
starts = 0
def __init__(self, bus, config):
super().__init__(name="slack", bus=bus, config=config)
async def start(self):
type(self).starts += 1
self._running = type(self).starts >= 2
async def stop(self):
self._running = False
async def send(self, msg):
return None
monkeypatch.setattr(
"deerflow.reflection.resolve_class",
lambda import_path, base_class=None: FlakyReadyChannel,
)
async def go():
service = ChannelService(
channels_config={
"slack": {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
}
)
try:
await service.start()
assert FlakyReadyChannel.starts == 2
assert service.get_status()["channels"]["slack"]["running"] is True
finally:
await service.stop()
_run(go())
def test_connection_repo_is_forwarded_to_manager(self):
from app.channels.service import ChannelService
repo = object()
service = ChannelService(channels_config={}, connection_repo=repo)
assert service.manager._connection_repo is repo
def test_remove_channel_stops_running_channel_and_forgets_config(self):
from app.channels.service import ChannelService
async def go():
service = ChannelService(
channels_config={
"slack": {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
}
)
channel = AsyncMock()
service._channels["slack"] = channel
service._running = True
assert await service.remove_channel("slack") is True
channel.stop.assert_awaited_once()
assert "slack" not in service._channels
assert "slack" not in service._config
_run(go())
def test_disabled_channel_with_string_creds_emits_warning(self, caplog): def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
"""Warning is emitted when a channel has string credentials but enabled=false.""" """Warning is emitted when a channel has string credentials but enabled=false."""
import logging import logging
@@ -3192,7 +3668,8 @@ class TestChannelService:
await service.stop() await service.stop()
_run(go()) _run(go())
assert any("wecom" in r.message and r.levelno == logging.WARNING for r in caplog.records) assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records)
assert all("wecom" not in r.message for r in caplog.records)
def test_disabled_channel_with_int_creds_emits_warning(self, caplog): def test_disabled_channel_with_int_creds_emits_warning(self, caplog):
"""Warning is emitted even when YAML-parsed integer credentials are present.""" """Warning is emitted even when YAML-parsed integer credentials are present."""
@@ -3212,7 +3689,8 @@ class TestChannelService:
await service.stop() await service.stop()
_run(go()) _run(go())
assert any("telegram" in r.message and r.levelno == logging.WARNING for r in caplog.records) assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records)
assert all("telegram" not in r.message for r in caplog.records)
def test_disabled_channel_without_creds_emits_info(self, caplog): def test_disabled_channel_without_creds_emits_info(self, caplog):
"""Only an info log (no warning) is emitted when a channel is disabled with no credentials.""" """Only an info log (no warning) is emitted when a channel is disabled with no credentials."""
@@ -3267,6 +3745,83 @@ class TestChannelService:
assert started_configs["feishu"]["app_secret"] == "new_secret" assert started_configs["feishu"]["app_secret"] == "new_secret"
assert service._config["feishu"]["app_id"] == "new_id" assert service._config["feishu"]["app_id"] == "new_id"
def test_configure_channel_keeps_explicit_config_over_stale_file_entry(self, monkeypatch):
"""UI-entered runtime credentials must not be clobbered by a config.yaml reload.
configure_channel() receives the authoritative config (e.g. from the
browser Connect/Modify dialog, never written to config.yaml), so its
restart must skip the file reload that restart_channel() performs for
operator-triggered restarts.
"""
from app.channels.service import ChannelService
stale_file_config = {"feishu": {"enabled": True, "app_id": "file_id", "app_secret": "file_secret"}}
def mock_get_app_config():
return SimpleNamespace(model_extra={"channels": stale_file_config})
monkeypatch.setattr("deerflow.config.app_config.get_app_config", mock_get_app_config)
service = ChannelService(channels_config={})
service._running = True
started_configs = {}
async def mock_start_channel(name, config):
started_configs[name] = config
return True
service._start_channel = mock_start_channel
async def go():
await service.configure_channel("feishu", {"enabled": True, "app_id": "ui_id", "app_secret": "ui_secret"})
_run(go())
assert started_configs["feishu"]["app_id"] == "ui_id"
assert started_configs["feishu"]["app_secret"] == "ui_secret"
assert service._config["feishu"]["app_id"] == "ui_id"
def test_restart_channel_reload_applies_runtime_store_overlay(self, monkeypatch, tmp_path):
"""An operator-triggered restart keeps UI runtime-store credentials for
channels that have no config.yaml entry."""
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"telegram",
{"enabled": True, "bot_token": "store-token"},
)
def mock_get_app_config():
return SimpleNamespace(
model_extra={"channels": {}},
channel_connections=ChannelConnectionsConfig.model_validate({"enabled": True, "telegram": {"enabled": True, "bot_username": "deerflow_bot"}}),
)
monkeypatch.setattr("deerflow.config.app_config.get_app_config", mock_get_app_config)
service = ChannelService(channels_config={})
started_configs = {}
async def mock_start_channel(name, config):
started_configs[name] = config
return True
service._start_channel = mock_start_channel
async def go():
await service.restart_channel("telegram")
_run(go())
assert started_configs["telegram"]["bot_token"] == "store-token"
def test_restart_channel_falls_back_to_cached_config_on_error(self, monkeypatch): def test_restart_channel_falls_back_to_cached_config_on_error(self, monkeypatch):
"""When get_app_config() fails, restart_channel uses cached config.""" """When get_app_config() fails, restart_channel uses cached config."""
from app.channels.service import ChannelService from app.channels.service import ChannelService
+12
View File
@@ -233,3 +233,15 @@ def test_non_auth_mutation_rejects_mismatched_double_submit_token():
assert response.status_code == 403 assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch." assert response.json()["detail"] == "CSRF token mismatch."
def test_channel_posts_require_double_submit_csrf():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/channels/slack/connect",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
@@ -0,0 +1,88 @@
"""Discord connection routing tests."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.channels.discord import DiscordChannel
from app.channels.message_bus import InboundMessage, MessageBus
@pytest.fixture
async def repo(tmp_path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'discord.db'}", sqlite_dir=str(tmp_path))
try:
yield ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("discord-secret"),
)
finally:
await close_engine()
@pytest.mark.anyio
async def test_discord_inbound_attaches_owner_identity_from_user_level_connection(repo):
connection = await repo.upsert_connection(
owner_user_id="alice",
provider="discord",
external_account_id="987",
external_account_name="Alice",
status="connected",
)
channel = DiscordChannel(
bus=MessageBus(),
config={"bot_token": "discord-bot", "connection_repo": repo},
)
inbound = InboundMessage(
channel_name="discord",
chat_id="C123",
user_id="987",
text="hello",
)
attached = await channel._attach_connection_identity(inbound, guild_id="G123")
assert attached.connection_id == connection["id"]
assert attached.owner_user_id == "alice"
assert attached.workspace_id is None
@pytest.mark.anyio
async def test_discord_connect_command_binds_gateway_identity(repo):
state = "discord-bind-code"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="discord",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = DiscordChannel(
bus=MessageBus(),
config={"bot_token": "discord-bot", "connection_repo": repo},
)
message = MagicMock()
message.author.id = 987
message.author.display_name = "Alice"
message.guild.id = 123
message.guild.name = "Deer Guild"
message.channel.id = 456
message.channel.send = AsyncMock()
handled = await channel._bind_connection_from_connect_code(message, state)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "discord"
assert connections[0]["external_account_id"] == "987"
assert connections[0]["external_account_name"] == "Alice"
assert connections[0]["workspace_id"] == "123"
assert connections[0]["workspace_name"] == "Deer Guild"
assert connections[0]["metadata"]["channel_id"] == "456"
message.channel.send.assert_awaited_once()
+25
View File
@@ -73,6 +73,31 @@ def test_feishu_on_message_plain_text():
assert mock_make_inbound.call_args[1]["text"] == "Hello world" assert mock_make_inbound.call_args[1]["text"] == "Hello world"
def test_feishu_is_not_running_when_ws_thread_exits():
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
channel._running = True
channel._thread = MagicMock()
channel._thread.is_alive.return_value = False
assert channel.is_running is False
def test_feishu_event_handler_ignores_non_content_message_events():
import lark_oapi as lark
bus = MessageBus()
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
event_handler = channel._build_event_handler(lark)
assert "p2.im.message.receive_v1" in event_handler._processorMap
assert "p2.im.message.message_read_v1" in event_handler._processorMap
assert "p2.im.message.reaction.created_v1" in event_handler._processorMap
assert "p2.im.message.reaction.deleted_v1" in event_handler._processorMap
assert "p2.im.message.recalled_v1" in event_handler._processorMap
def test_feishu_on_message_rich_text(): def test_feishu_on_message_rich_text():
bus = MessageBus() bus = MessageBus()
config = {"app_id": "test", "app_secret": "test"} config = {"app_id": "test", "app_secret": "test"}
+95
View File
@@ -4,6 +4,18 @@ from __future__ import annotations
import json import json
import pytest
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
@pytest.fixture
def _stub_app_config():
"""Keep run-context tests independent from a developer-local config.yaml."""
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
yield
reset_app_config()
def test_format_sse_basic(): def test_format_sse_basic():
from app.gateway.services import format_sse from app.gateway.services import format_sse
@@ -36,6 +48,12 @@ def test_format_sse_no_event_id():
assert "id:" not in frame assert "id:" not in frame
def test_sanitize_log_param_strips_control_characters():
from app.gateway.utils import sanitize_log_param
assert sanitize_log_param("thread\nid\rwith\x00controls") == "threadidwithcontrols"
def test_normalize_stream_modes_none(): def test_normalize_stream_modes_none():
from app.gateway.services import normalize_stream_modes from app.gateway.services import normalize_stream_modes
@@ -474,6 +492,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
assert config["context"]["user_id"] == "channel-user-7" assert config["context"]["user_id"] == "channel-user-7"
def test_start_run_uses_internal_owner_header_for_persistence(_stub_app_config):
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
from app.gateway.services import start_run
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime import RunManager
from deerflow.runtime.runs.store.memory import MemoryRunStore
from deerflow.runtime.user_context import get_effective_user_id
async def _scenario():
run_store = MemoryRunStore()
thread_store = MemoryThreadMetaStore(InMemoryStore())
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
run_manager = RunManager(store=run_store)
state = SimpleNamespace(
stream_bridge=SimpleNamespace(),
run_manager=run_manager,
checkpointer=InMemorySaver(),
store=InMemoryStore(),
run_event_store=SimpleNamespace(),
run_events_config=None,
thread_store=thread_store,
)
request = SimpleNamespace(
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
app=SimpleNamespace(state=state),
)
body = SimpleNamespace(
assistant_id="lead_agent",
input={"messages": [{"role": "human", "content": "hi"}]},
metadata={},
config=None,
context=None,
on_disconnect="cancel",
multitask_strategy="reject",
stream_mode=None,
stream_subgraphs=False,
interrupt_before=None,
interrupt_after=None,
)
task_context: dict[str, str] = {}
async def fake_run_agent(*args, **kwargs):
task_context["user_id"] = get_effective_user_id()
with (
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
):
record = await start_run(body, "channel-thread", request)
await record.task
owner_run = await run_store.get(record.run_id, user_id="owner-1")
default_run = await run_store.get(record.run_id, user_id="default")
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
default_thread = await thread_store.get("channel-thread", user_id="default")
return owner_run, default_run, owner_thread, default_thread, task_context
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
assert owner_run is not None
assert owner_run["user_id"] == "owner-1"
assert default_run is None
assert owner_thread is not None
assert owner_thread["user_id"] == "owner-1"
assert owner_thread["metadata"] == {"legacy": True}
assert default_thread is None
assert task_context["user_id"] == "owner-1"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0) # build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+15
View File
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
assert reloaded.is_valid_internal_auth_token(token) is True assert reloaded.is_valid_internal_auth_token(token) is True
finally: finally:
importlib.reload(reloaded) importlib.reload(reloaded)
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
import app.gateway.internal_auth as internal_auth
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
reloaded = importlib.reload(internal_auth)
try:
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
finally:
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
importlib.reload(reloaded)
@@ -44,6 +44,7 @@ class TestMakeSafeUserId:
# Sanitized prefix plus a stable digest of the original. # Sanitized prefix plus a stable digest of the original.
assert result.startswith("user-example-com-") assert result.startswith("user-example-com-")
assert len(result.rsplit("-", 1)[1]) == 16 assert len(result.rsplit("-", 1)[1]) == 16
assert result == "user-example-com-b4c9a289323b21a0"
assert make_safe_user_id("user@example.com") == result assert make_safe_user_id("user@example.com") == result
def test_sanitized_id_passes_validation(self, paths: Paths): def test_sanitized_id_passes_validation(self, paths: Paths):
@@ -69,6 +70,40 @@ class TestUserDir:
def test_user_dir(self, paths: Paths): def test_user_dir(self, paths: Paths):
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice" assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
def test_prepare_user_dir_migrates_unique_legacy_unsafe_bucket(self, paths: Paths):
from deerflow.config.paths import make_safe_user_id
raw = "user@example.com"
safe = make_safe_user_id(raw)
legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b"
legacy_dir.mkdir(parents=True)
(legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8")
assert paths.prepare_user_dir_for_raw_id(raw) == safe
current_dir = paths.user_dir(safe)
assert current_dir.exists()
assert not legacy_dir.exists()
assert (current_dir / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n'
def test_prepare_user_dir_never_migrates_another_users_bucket(self, paths: Paths):
"""A different raw ID with the same sanitized prefix has a different legacy digest."""
import hashlib
from deerflow.config.paths import make_safe_user_id
users_dir = paths.base_dir / "users"
other_legacy = users_dir / f"a-b-{hashlib.sha1(b'a/b').hexdigest()[:16]}"
other_legacy.mkdir(parents=True)
arbitrary_16_hex = users_dir / "a-b-1111111111111111"
arbitrary_16_hex.mkdir(parents=True)
assert paths.prepare_user_dir_for_raw_id("a.b") == make_safe_user_id("a.b")
assert not paths.user_dir(make_safe_user_id("a.b")).exists()
assert other_legacy.exists()
assert arbitrary_16_hex.exists()
class TestUserMemoryFile: class TestUserMemoryFile:
def test_user_memory_file(self, paths: Paths): def test_user_memory_file(self, paths: Paths):
+1
View File
@@ -90,6 +90,7 @@ def test_appconfig_descriptions_retain_original_field_documentation():
"run_events": "memory for dev", "run_events": "memory for dev",
"checkpointer": "state-persistence checkpointer", "checkpointer": "state-persistence checkpointer",
"stream_bridge": "Stream bridge", "stream_bridge": "Stream bridge",
"channel_connections": "IM channel connection",
} }
for field_name, expected_substring in descriptions.items(): for field_name, expected_substring in descriptions.items():
description = AppConfig.model_fields[field_name].description or "" description = AppConfig.model_fields[field_name].description or ""
+75
View File
@@ -7,7 +7,9 @@ Run from repo root:
from __future__ import annotations from __future__ import annotations
import yaml import yaml
from wizard import ui as wizard_ui
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
from wizard.steps import channels as channels_step
from wizard.steps import llm as llm_step from wizard.steps import llm as llm_step
from wizard.steps import search as search_step from wizard.steps import search as search_step
from wizard.writer import ( from wizard.writer import (
@@ -327,6 +329,44 @@ class TestBuildMinimalConfig:
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled" assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled" assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
def test_can_enable_selected_channel_connections(self):
content = build_minimal_config(
provider_use="langchain_openai:ChatOpenAI",
model_name="gpt-4o",
display_name="OpenAI",
api_key_field="api_key",
env_var="OPENAI_API_KEY",
channel_connection_providers=["feishu", "slack"],
)
data = yaml.safe_load(content)
channel_connections = data["channel_connections"]
assert channel_connections["enabled"] is True
assert channel_connections["feishu"]["enabled"] is True
assert channel_connections["slack"]["enabled"] is True
assert channel_connections["telegram"]["enabled"] is False
assert channel_connections["discord"]["enabled"] is False
assert channel_connections["dingtalk"]["enabled"] is False
assert channel_connections["wechat"]["enabled"] is False
assert channel_connections["wecom"]["enabled"] is False
def test_channel_connections_disabled_when_no_channels_selected(self):
content = build_minimal_config(
provider_use="langchain_openai:ChatOpenAI",
model_name="gpt-4o",
display_name="OpenAI",
api_key_field="api_key",
env_var="OPENAI_API_KEY",
channel_connection_providers=[],
)
data = yaml.safe_load(content)
channel_connections = data["channel_connections"]
assert channel_connections["enabled"] is False
assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled")
class TestLLMStep: class TestLLMStep:
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch): def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
@@ -384,6 +424,41 @@ class TestLLMStep:
assert result.base_url == "https://gateway.example/v1" assert result.base_url == "https://gateway.example/v1"
class TestChannelsStep:
def test_returns_selected_channel_keys(self, monkeypatch):
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6])
result = channels_step.run_channels_step()
assert result.enabled_providers == ["telegram", "feishu", "wecom"]
def test_empty_selection_disables_channel_connections(self, monkeypatch):
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [])
result = channels_step.run_channels_step()
assert result.enabled_providers == []
class TestWizardUi:
def test_multi_choice_blank_requires_input_without_default(self, monkeypatch):
answers = iter(["", "2"])
monkeypatch.setattr("builtins.input", lambda _prompt: next(answers))
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1]
def test_multi_choice_blank_accepts_empty_default(self, monkeypatch):
monkeypatch.setattr("builtins.input", lambda _prompt: "")
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == []
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# writer.py — env file helpers # writer.py — env file helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -0,0 +1,154 @@
"""Slack connection tests for user-owned channel bindings."""
from __future__ import annotations
import sys
from datetime import UTC, datetime, timedelta
from types import ModuleType
from unittest.mock import AsyncMock, MagicMock
from app.channels.message_bus import MessageBus, OutboundMessage
async def _make_repo(tmp_path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'slack.db'}", sqlite_dir=str(tmp_path))
return ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("slack-secret"),
)
def test_slack_connect_command_binds_socket_mode_identity(tmp_path):
import anyio
from app.channels.slack import SlackChannel
async def go():
repo = await _make_repo(tmp_path)
state = "slack-bind-code"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="slack",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = SlackChannel(
bus=MessageBus(),
config={"bot_token": "xoxb-operator", "app_token": "xapp-operator", "connection_repo": repo},
)
channel._web_client = MagicMock()
handled = await channel._bind_connection_from_connect_code(
event={
"user": "U123",
"channel": "C123",
"ts": "1710000000.000100",
},
team_id="T123",
code=state,
)
connections = await repo.list_connections("deerflow-user-1")
assert handled is True
assert len(connections) == 1
assert connections[0]["provider"] == "slack"
assert connections[0]["external_account_id"] == "U123"
assert connections[0]["workspace_id"] == "T123"
assert connections[0]["metadata"]["channel_id"] == "C123"
channel._web_client.chat_postMessage.assert_called_once()
await repo.close()
anyio.run(go)
def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
import anyio
from app.channels.slack import SlackChannel
async def go():
repo = AsyncMock()
repo.get_credentials.return_value = {"access_token": "xoxb-connection-token"}
web_client = MagicMock()
web_client_factory = MagicMock(return_value=web_client)
channel = SlackChannel(
bus=MessageBus(),
config={
"connection_repo": repo,
"web_client_factory": web_client_factory,
},
)
msg = OutboundMessage(
channel_name="slack",
chat_id="C123",
thread_id="thread-1",
text="hello",
connection_id="connection-1",
)
await channel.send(msg)
repo.get_credentials.assert_awaited_once_with("connection-1")
web_client_factory.assert_called_once_with(token="xoxb-connection-token")
web_client.chat_postMessage.assert_called_once()
anyio.run(go)
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
import anyio
from app.channels.slack import SlackChannel
class FakeWebClient:
def __init__(self, token: str) -> None:
self.token = token
self.messages: list[dict] = []
def auth_test(self):
return {"user_id": "B-http"}
def chat_postMessage(self, **kwargs):
self.messages.append(kwargs)
slack_sdk = ModuleType("slack_sdk")
slack_sdk.WebClient = FakeWebClient
socket_mode = ModuleType("slack_sdk.socket_mode")
socket_mode.SocketModeClient = object
response = ModuleType("slack_sdk.socket_mode.response")
response.SocketModeResponse = object
monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk)
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode)
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response)
async def go():
channel = SlackChannel(
bus=MessageBus(),
config={
"bot_token": "xoxb-operator",
"event_delivery": "http",
"connection_repo": MagicMock(),
},
)
await channel.start()
assert channel._running is True
assert channel._web_client is not None
assert channel._web_client.token == "xoxb-operator"
assert channel._bot_user_id == "B-http"
await channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
assert channel._web_client.messages == [
{
"channel": "C123",
"text": "Slack connected to DeerFlow.",
"thread_ts": "1710000000.000100",
}
]
await channel.stop()
anyio.run(go)
@@ -164,10 +164,42 @@ def test_stream_shared_thread_passes_owner_check():
create_or_reject.assert_awaited() create_or_reject.assert_awaited()
def test_stream_internal_role_bypasses_owner_check(): def test_stream_internal_role_scoped_by_owner_header():
"""IM channels run with the internal system role on behalf of platform """IM channels run with the internal system role on behalf of the
users whose threads they do not own the owner check must not break them.""" connection owner named in X-DeerFlow-Owner-User-Id the owner check is
scoped to that owner rather than bypassed."""
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
with _client(INTERNAL_USER) as (client, create_or_reject): with _client(INTERNAL_USER) as (client, create_or_reject):
response = client.post("/api/runs/stream", json=_body(THREAD_A)) response = client.post(
"/api/runs/stream",
json=_body(THREAD_A),
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: str(USER_A.id)},
)
assert response.status_code == 409 assert response.status_code == 409
create_or_reject.assert_awaited() create_or_reject.assert_awaited()
def test_stream_internal_role_with_foreign_owner_header_returns_404():
"""The internal token alone must not grant access to another user's thread."""
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
with _client(INTERNAL_USER) as (client, create_or_reject):
response = client.post(
"/api/runs/stream",
json=_body(THREAD_A),
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: str(USER_B.id)},
)
assert response.status_code == 404
create_or_reject.assert_not_awaited()
def test_stream_internal_role_without_owner_header_is_scoped_to_internal_user():
"""Without an owner header internal callers keep access to their own and
shared/untracked threads, but not to user-owned threads."""
with _client(INTERNAL_USER) as (client, create_or_reject):
denied = client.post("/api/runs/stream", json=_body(THREAD_A))
allowed = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
assert denied.status_code == 404
assert allowed.status_code == 409
create_or_reject.assert_awaited()
@@ -0,0 +1,100 @@
"""Tests for Telegram deep-link channel connections."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.channels.message_bus import MessageBus
from app.channels.telegram import TelegramChannel
@pytest.fixture
async def repo(tmp_path: Path):
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'telegram.db'}", sqlite_dir=str(tmp_path))
try:
yield ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("telegram-secret"),
)
finally:
await close_engine()
def _telegram_update(*, text: str = "/start", user_id: int = 42, chat_id: int = 100, chat_type: str = "private"):
update = MagicMock()
update.effective_user.id = user_id
update.effective_user.username = "alice"
update.effective_user.full_name = "Alice Example"
update.effective_chat.id = chat_id
update.effective_chat.type = chat_type
update.message.text = text
update.message.message_id = 55
update.message.reply_to_message = None
update.message.reply_text = AsyncMock()
return update
@pytest.mark.anyio
async def test_start_with_deep_link_state_binds_telegram_chat(repo):
state = "telegram-bind-state"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="telegram",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = TelegramChannel(
bus=MessageBus(),
config={"bot_token": "test-token", "connection_repo": repo},
)
update = _telegram_update(text=f"/start {state}")
context = MagicMock()
context.args = [state]
await channel._cmd_start(update, context)
connections = await repo.list_connections("deerflow-user-1")
assert len(connections) == 1
assert connections[0]["provider"] == "telegram"
assert connections[0]["external_account_id"] == "42"
assert connections[0]["external_account_name"] == "Alice Example"
assert connections[0]["workspace_id"] == "100"
assert connections[0]["metadata"]["chat_type"] == "private"
update.message.reply_text.assert_awaited_once()
assert "connected" in update.message.reply_text.await_args.args[0].lower()
@pytest.mark.anyio
async def test_bound_telegram_message_publishes_connection_identity(repo):
connection = await repo.upsert_connection(
owner_user_id="deerflow-user-1",
provider="telegram",
external_account_id="42",
external_account_name="Alice Example",
workspace_id="100",
metadata={"chat_type": "private"},
)
bus = MessageBus()
channel = TelegramChannel(
bus=bus,
config={"bot_token": "test-token", "connection_repo": repo},
)
channel._main_loop = __import__("asyncio").get_event_loop()
channel._send_running_reply = AsyncMock()
await channel._on_text(_telegram_update(text="hello"), None)
inbound = await bus.get_inbound()
assert inbound.connection_id == connection["id"]
assert inbound.owner_user_id == "deerflow-user-1"
assert inbound.workspace_id == "100"
assert inbound.user_id == "42"
assert inbound.chat_id == "100"
assert inbound.text == "hello"
+13
View File
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
async def test_update_metadata_nonexistent_is_noop(self, repo): async def test_update_metadata_nonexistent_is_noop(self, repo):
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
@pytest.mark.anyio
async def test_update_owner_with_bypass_moves_row(self, repo):
await repo.create("t1", user_id="default", metadata={"source": "channel"})
await repo.update_owner("t1", "owner-1", user_id=None)
owner_row = await repo.get("t1", user_id="owner-1")
default_row = await repo.get("t1", user_id="default")
assert owner_row is not None
assert owner_row["user_id"] == "owner-1"
assert owner_row["metadata"] == {"source": "channel"}
assert default_row is None
# --- search with metadata filter (SQL push-down) --- # --- search with metadata filter (SQL push-down) ---
@pytest.mark.anyio @pytest.mark.anyio
+32
View File
@@ -1,4 +1,5 @@
import re import re
from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
assert body["created_at"] == body["updated_at"] assert body["created_at"] == body["updated_at"]
def test_internal_owner_header_assigns_thread_to_owner() -> None:
import asyncio
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
store = InMemoryStore()
checkpointer = InMemorySaver()
thread_store = MemoryThreadMetaStore(store)
request = SimpleNamespace(
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
)
async def _scenario():
response = await threads.create_thread(
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
request,
)
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
internal_row = await thread_store.get("channel-thread", user_id="default")
return response, owner_row, internal_row
response, owner_row, internal_row = asyncio.run(_scenario())
assert response.thread_id == "channel-thread"
assert owner_row is not None
assert owner_row["user_id"] == "owner-1"
assert internal_row is None
def test_get_thread_returns_iso_for_legacy_unix_record() -> None: def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
"""A thread record written by older versions stores ``time.time()`` """A thread record written by older versions stores ``time.time()``
floats. ``get_thread`` must transparently surface them as ISO so the floats. ``get_thread`` must transparently surface them as ISO so the
+2
View File
@@ -820,6 +820,7 @@ dependencies = [
{ name = "agent-sandbox" }, { name = "agent-sandbox" },
{ name = "aiosqlite" }, { name = "aiosqlite" },
{ name = "alembic" }, { name = "alembic" },
{ name = "cryptography" },
{ name = "ddgs" }, { name = "ddgs" },
{ name = "dotenv" }, { name = "dotenv" },
{ name = "duckdb" }, { name = "duckdb" },
@@ -871,6 +872,7 @@ requires-dist = [
{ name = "aiosqlite", specifier = ">=0.19" }, { name = "aiosqlite", specifier = ">=0.19" },
{ name = "alembic", specifier = ">=1.13" }, { name = "alembic", specifier = ">=1.13" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" }, { name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
{ name = "cryptography", specifier = ">=43.0.0" },
{ name = "ddgs", specifier = ">=9.10.0" }, { name = "ddgs", specifier = ">=9.10.0" },
{ name = "dotenv", specifier = ">=0.9.9" }, { name = "dotenv", specifier = ">=0.9.9" },
{ name = "duckdb", specifier = ">=1.4.4" }, { name = "duckdb", specifier = ">=1.4.4" },
+39
View File
@@ -1140,6 +1140,45 @@ run_events:
max_trace_content: 10240 max_trace_content: 10240
track_token_usage: true track_token_usage: true
# ============================================================================
# User-Owned IM Channel Connections
# ============================================================================
# Lets logged-in users connect their own IM accounts from the DeerFlow frontend
# while reusing the existing `channels` runtime configuration below.
#
# Security notes:
# - No public IP, OAuth callback URL, or provider webhook is required.
# - Provider bot/app credentials stay under `channels.*`.
# - `channel_connections` stores per-user bindings and one-time connect codes.
# - Telegram uses a deep link when `bot_username` is configured.
# - Slack, Discord, Feishu, DingTalk, WeChat, and WeCom use `/connect <code>`
# through the already-running bot/app.
#
# channel_connections:
# enabled: false
#
# telegram:
# enabled: false
# bot_username: $TELEGRAM_BOT_USERNAME
#
# slack:
# enabled: false
#
# discord:
# enabled: false
#
# feishu:
# enabled: false
#
# dingtalk:
# enabled: false
#
# wechat:
# enabled: false
#
# wecom:
# enabled: false
# ============================================================================ # ============================================================================
# IM Channels Configuration # IM Channels Configuration
# ============================================================================ # ============================================================================
+1
View File
@@ -52,6 +52,7 @@ src/
├── core/ # Core business logic ├── core/ # Core business logic
│ ├── api/ # API client & data fetching │ ├── api/ # API client & data fetching
│ ├── artifacts/ # Artifact management │ ├── artifacts/ # Artifact management
│ ├── channels/ # IM channel connections (providers, connect flow)
│ ├── config/ # App configuration │ ├── config/ # App configuration
│ ├── i18n/ # Internationalization │ ├── i18n/ # Internationalization
│ ├── mcp/ # MCP integration │ ├── mcp/ # MCP integration
+1
View File
@@ -48,6 +48,7 @@ The frontend is a stateful chat application. Users create **threads** (conversat
- `threads/` — Thread creation, streaming, state management (hooks + types) - `threads/` — Thread creation, streaming, state management (hooks + types)
- `api/` — LangGraph client singleton - `api/` — LangGraph client singleton
- `artifacts/` — Artifact loading and caching - `artifacts/` — Artifact loading and caching
- `channels/` — IM channel connections (provider catalog, connect/runtime-config API + hooks)
- `i18n/` — Internationalization (en-US, zh-CN) - `i18n/` — Internationalization (en-US, zh-CN)
- `settings/` — User preferences in localStorage - `settings/` — User preferences in localStorage
- `memory/` — Persistent user memory system - `memory/` — Persistent user memory system
+32 -14
View File
@@ -6,6 +6,10 @@ import { useEffect, useMemo, useRef, useState } from "react";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input"; import { Input } from "@/components/ui/input";
import { ScrollArea } from "@/components/ui/scroll-area"; import { ScrollArea } from "@/components/ui/scroll-area";
import {
ThreadChannelBadge,
ThreadChannelIcon,
} from "@/components/workspace/thread-channel-source";
import { import {
WorkspaceBody, WorkspaceBody,
WorkspaceContainer, WorkspaceContainer,
@@ -13,7 +17,11 @@ import {
} from "@/components/workspace/workspace-container"; } from "@/components/workspace/workspace-container";
import { useI18n } from "@/core/i18n/hooks"; import { useI18n } from "@/core/i18n/hooks";
import { useInfiniteThreads } from "@/core/threads/hooks"; import { useInfiniteThreads } from "@/core/threads/hooks";
import { pathOfThread, titleOfThread } from "@/core/threads/utils"; import {
channelSourceOfThread,
pathOfThread,
titleOfThread,
} from "@/core/threads/utils";
import { formatTimeAgo } from "@/core/utils/datetime"; import { formatTimeAgo } from "@/core/utils/datetime";
export default function ChatsPage() { export default function ChatsPage() {
@@ -82,20 +90,30 @@ export default function ChatsPage() {
<main className="min-h-0 flex-1"> <main className="min-h-0 flex-1">
<ScrollArea className="size-full py-4"> <ScrollArea className="size-full py-4">
<div className="mx-auto flex size-full max-w-(--container-width-md) flex-col"> <div className="mx-auto flex size-full max-w-(--container-width-md) flex-col">
{filteredThreads?.map((thread) => ( {filteredThreads.map((thread) => {
<Link key={thread.thread_id} href={pathOfThread(thread)}> const channelSource = channelSourceOfThread(thread);
<div className="flex flex-col gap-2 border-b p-4"> return (
<div> <Link key={thread.thread_id} href={pathOfThread(thread)}>
<div>{titleOfThread(thread)}</div> <div className="flex flex-col gap-2 border-b p-4">
</div> <div className="flex min-w-0 items-center gap-2">
{thread.updated_at && ( <ThreadChannelIcon source={channelSource} />
<div className="text-muted-foreground text-sm"> <div className="min-w-0 flex-1 truncate">
{formatTimeAgo(thread.updated_at)} {titleOfThread(thread)}
</div>
<ThreadChannelBadge
source={channelSource}
className="hidden sm:inline-flex"
/>
</div> </div>
)} {thread.updated_at && (
</div> <div className="text-muted-foreground text-sm">
</Link> {formatTimeAgo(thread.updated_at)}
))} </div>
)}
</div>
</Link>
);
})}
{hasNextPage && !isSearching && ( {hasNextPage && !isSearching && (
<div <div
ref={sentinelRef} ref={sentinelRef}
@@ -0,0 +1,184 @@
"use client";
import { MessageCircleIcon } from "lucide-react";
import type { SVGProps } from "react";
import { cn } from "@/lib/utils";
type ChannelProviderIconProps = SVGProps<SVGSVGElement> & {
provider: string;
};
export function ChannelProviderIcon({
provider,
className,
...props
}: ChannelProviderIconProps) {
const normalizedProvider = provider.toLowerCase();
if (normalizedProvider === "telegram") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<circle cx="12" cy="12" r="11" fill="#2AABEE" />
<path
fill="#FFFFFF"
d="M17.4 7.2 15.7 16c-.1.7-.5.9-1 .6l-2.8-2.1-1.4 1.3c-.1.2-.3.3-.6.3l.2-2.9 5.3-4.8c.2-.2 0-.3-.3-.1l-6.6 4.1-2.8-.9c-.6-.2-.6-.6.1-.8l10.9-4.2c.5-.2.9.1.7.7Z"
/>
</svg>
);
}
if (normalizedProvider === "slack") {
return (
<svg
viewBox="0 0 256 256"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<path
fill="#E01E5A"
d="M53.841 161.32c0 14.832-11.987 26.82-26.819 26.82S.203 176.152.203 161.32c0-14.831 11.987-26.818 26.82-26.818H53.84zm13.41 0c0-14.831 11.987-26.818 26.819-26.818s26.819 11.987 26.819 26.819v67.047c0 14.832-11.987 26.82-26.82 26.82c-14.83 0-26.818-11.988-26.818-26.82z"
/>
<path
fill="#36C5F0"
d="M94.07 53.638c-14.832 0-26.82-11.987-26.82-26.819S79.239 0 94.07 0s26.819 11.987 26.819 26.819v26.82zm0 13.613c14.832 0 26.819 11.987 26.819 26.819s-11.987 26.819-26.82 26.819H26.82C11.987 120.889 0 108.902 0 94.069c0-14.83 11.987-26.818 26.819-26.818z"
/>
<path
fill="#2EB67D"
d="M201.55 94.07c0-14.832 11.987-26.82 26.818-26.82s26.82 11.988 26.82 26.82s-11.988 26.819-26.82 26.819H201.55zm-13.41 0c0 14.832-11.988 26.819-26.82 26.819c-14.831 0-26.818-11.987-26.818-26.82V26.82C134.502 11.987 146.489 0 161.32 0s26.819 11.987 26.819 26.819z"
/>
<path
fill="#ECB22E"
d="M161.32 201.55c14.832 0 26.82 11.987 26.82 26.818s-11.988 26.82-26.82 26.82c-14.831 0-26.818-11.988-26.818-26.82V201.55zm0-13.41c-14.831 0-26.818-11.988-26.818-26.82c0-14.831 11.987-26.818 26.819-26.818h67.25c14.832 0 26.82 11.987 26.82 26.819s-11.988 26.819-26.82 26.819z"
/>
</svg>
);
}
if (normalizedProvider === "discord") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<circle cx="12" cy="12" r="11" fill="#5865F2" />
<path
fill="#FFFFFF"
d="M8.1 8.4c1.4-.6 2.7-.7 3.9-.7s2.5.1 3.9.7c1 1.5 1.5 3.1 1.4 4.8-.9.7-1.8 1.1-2.8 1.3l-.7-1.1c.4-.1.7-.3 1.1-.5-.3.1-.6.3-.9.4-.7.3-1.4.4-2 .4s-1.3-.1-2-.4c-.3-.1-.6-.2-.9-.4.3.2.7.4 1.1.5l-.7 1.1c-1-.2-1.9-.6-2.8-1.3-.1-1.7.4-3.3 1.4-4.8Zm2.1 3.9c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Zm3.6 0c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Z"
/>
</svg>
);
}
if (normalizedProvider === "feishu") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<rect
x="1.25"
y="1.25"
width="21.5"
height="21.5"
rx="5.25"
fill="#FFFFFF"
stroke="#E5E7EB"
strokeWidth=".5"
/>
<path
d="M6.1 4.5h8.3c.9 0 1.7.4 2.2 1.1a16 16 0 0 1 2.9 6.2c-1.8-.8-3.8-.9-5.9-.3L5.6 5.8c-.6-.5-.3-1.3.5-1.3Z"
fill="#14D6C5"
/>
<path
d="M3.2 8.9c3.6 3.4 7.5 5.7 11.7 6.8 2.7.7 5.1.4 7-.6-1.6 3.1-5.2 5.4-9.4 5.4-3.4 0-6.7-.9-9.2-2.6a2 2 0 0 1-.9-1.7V9.6c0-.7.4-1 .8-.7Z"
fill="#3370FF"
/>
<path
d="M11 14.1c2.3-3.1 6.1-4.6 10.5-3.3l.8.2-2.6 4.1a6.3 6.3 0 0 1-6 2.9c-1.9-.2-3.9-.8-5.9-1.7 1.1-.7 2.2-1.4 3.2-2.2Z"
fill="#1E3A9F"
/>
</svg>
);
}
if (normalizedProvider === "dingtalk") {
return (
<svg
viewBox="0 0 1024 1024"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<g transform="translate(512 512) scale(1.35) translate(-512 -512)">
<path
fill="#0B86FF"
d="M739 449.3c-1 4.2-3.5 10.4-7 17.8h.1l-.4.7c-20.3 43.1-73.1 127.7-73.1 127.7l-.3-.5-15.5 26.8h74.5L575.1 810l32.3-128h-58.6l20.4-84.7c-16.5 3.9-35.9 9.4-59 16.8 0 0-31.2 18.2-89.9-35 0 0-39.6-34.7-16.6-43.4 9.8-3.7 47.4-8.4 77-12.3 40-5.4 64.6-8.2 64.6-8.2S422 517 392.7 512.5c-29.3-4.6-66.4-53.1-74.3-95.8 0 0-12.2-23.4 26.3-12.3s197.9 43.2 197.9 43.2-207.4-63.3-221.2-78.7-40.6-84.2-37.1-126.5c0 0 1.5-10.5 12.4-7.7 0 0 153.3 69.7 258.1 107.9 104.8 37.9 195.9 57.3 184.2 106.7Z"
/>
</g>
</svg>
);
}
if (normalizedProvider === "wechat") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<circle cx="12" cy="12" r="11" fill="#07C160" />
<path
fill="#FFFFFF"
d="M10.4 6.5c-3 0-5.4 2-5.4 4.5 0 1.4.8 2.7 2.1 3.5l-.5 1.8 2-.9c.6.1 1.2.2 1.8.2 3 0 5.4-2 5.4-4.5s-2.4-4.6-5.4-4.6Zm-1.9 3.7a.7.7 0 1 1 0-1.4.7.7 0 0 1 0 1.4Zm3.7 0a.7.7 0 1 1 0-1.4.7.7 0 0 1 0 1.4Z"
/>
<path
fill="#FFFFFF"
fillOpacity=".86"
d="M14.4 12.3c2.5 0 4.6 1.7 4.6 3.8 0 1.1-.6 2.1-1.6 2.8l.4 1.5-1.7-.8c-.5.1-1.1.2-1.7.2-2.5 0-4.6-1.7-4.6-3.8s2.1-3.7 4.6-3.7Zm-1.6 3.1a.6.6 0 1 0 0-1.2.6.6 0 0 0 0 1.2Zm3.1 0a.6.6 0 1 0 0-1.2.6.6 0 0 0 0 1.2Z"
/>
</svg>
);
}
if (normalizedProvider === "wecom") {
return (
<svg
viewBox="0 0 24 24"
aria-hidden="true"
className={cn("size-5", className)}
{...props}
>
<rect
x="1.25"
y="1.25"
width="21.5"
height="21.5"
rx="5.25"
fill="#FFFFFF"
stroke="#E5E7EB"
strokeWidth=".5"
/>
<path
fill="#168DEB"
d="m17.326 8.158-.003-.007a6.6 6.6 0 0 0-1.178-1.674c-1.266-1.307-3.067-2.19-5.102-2.417a9.3 9.3 0 0 0-2.124 0h-.001c-2.061.228-3.882 1.107-5.14 2.405a6.7 6.7 0 0 0-1.194 1.682A5.7 5.7 0 0 0 2 10.657c0 1.106.332 2.218.988 3.201l.006.01c.391.594 1.092 1.39 1.637 1.83l.983.793-.208.875.527-.267.708-.358.761.225c.467.137.955.227 1.517.29h.005q.515.06 1.026.059c.355 0 .724-.02 1.095-.06a9 9 0 0 0 1.346-.258c.095.7.43 1.337.932 1.81-.658.208-1.352.358-2.061.436-.442.048-.883.072-1.312.072q-.627 0-1.253-.072a10.7 10.7 0 0 1-1.861-.36l-2.84 1.438s-.29.131-.44.131c-.418 0-.702-.285-.702-.704 0-.252.067-.598.128-.84l.394-1.653c-.728-.586-1.563-1.544-2.052-2.287A7.76 7.76 0 0 1 0 10.658a7.7 7.7 0 0 1 .787-3.39 8.7 8.7 0 0 1 1.551-2.19c1.61-1.665 3.878-2.73 6.359-3.006a11.3 11.3 0 0 1 2.565 0c2.47.275 4.712 1.353 6.323 3.017a8.6 8.6 0 0 1 1.539 2.192c.466.945.769 1.937.769 2.978a3.06 3.06 0 0 0-2-.005c-.001-.644-.189-1.329-.564-2.09zm4.125 6.977-.024-.024-.024-.018-.024-.018-.096-.095a4.24 4.24 0 0 1-1.169-2.192q0-.038-.006-.075l-.006-.056-.035-.144a1.3 1.3 0 0 0-.358-.61 1.386 1.386 0 0 0-1.957 0 1.4 1.4 0 0 0 0 1.963c.191.191.418.311.668.371.024.012.06.012.084.012q.019 0 .041.006.023.005.042.006a4.24 4.24 0 0 1 2.231 1.186c.048.048.096.095.131.143a.323.323 0 0 0 .466 0 .35.35 0 0 0 .036-.455m-1.05 4.37-.025.025c-.119.096-.31.096-.453-.036a.326.326 0 0 1 0-.467c.047-.036.094-.083.141-.13l.002-.002a4.27 4.27 0 0 0 1.187-2.28q.005-.024.006-.043c0-.024 0-.06.012-.084a1.386 1.386 0 0 1 2.326-.67 1.4 1.4 0 0 1 0 1.964c-.167.18-.382.299-.608.359l-.143.036-.057.005q-.035.006-.075.007a4.2 4.2 0 0 0-2.183 1.173l-.095.096q-.009.01-.018.024t-.018.024m-4.392-1.053.024.024.024.018q.015.009.024.018l.096.096a4.25 4.25 0 0 1 1.169 2.19q0 .04.006.076.005.03.006.057l.035.143c.06.228.18.443.358.611.537.539 1.42.539 1.957 0a1.4 1.4 0 0 0 0-1.964 1.4 1.4 0 0 0-.668-.371c-.024-.012-.06-.012-.084-.012q-.018 0-.041-.006l-.042-.006a4.25 4.25 0 0 1-2.231-1.185 1.4 1.4 0 0 1-.131-.144.323.323 0 0 0-.466 0 .325.325 0 0 0-.036.455m1.039-4.358.024-.024a.32.32 0 0 1 .453.035.326.326 0 0 1 0 .467c-.047.036-.094.083-.141.13l-.002.002a4.27 4.27 0 0 0-1.187 2.281l-.006.042c0 .024 0 .06-.012.084a1.386 1.386 0 0 1-2.326.67 1.4 1.4 0 0 1 0-1.963c.166-.18.381-.3.608-.36l.143-.035q.026 0 .056-.006.037-.005.075-.006a4.2 4.2 0 0 0 2.183-1.174l.096-.095.018-.025z"
/>
</svg>
);
}
return (
<MessageCircleIcon aria-hidden="true" className={cn("size-5", className)} />
);
}
@@ -0,0 +1,159 @@
"use client";
import { LoaderCircleIcon } from "lucide-react";
import {
type CSSProperties,
type FormEvent,
useEffect,
useMemo,
useState,
} from "react";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import type {
ChannelProvider,
ChannelRuntimeConfigValues,
} from "@/core/channels/types";
import { useI18n } from "@/core/i18n/hooks";
type ChannelRuntimeConfigDialogProps = {
provider: ChannelProvider | null;
open: boolean;
submitting: boolean;
onOpenChange: (open: boolean) => void;
onSubmit: (
provider: ChannelProvider,
values: ChannelRuntimeConfigValues,
) => void;
};
type SecretInputStyle = CSSProperties & {
WebkitTextSecurity?: "disc";
};
const SECRET_INPUT_STYLE: SecretInputStyle = {
WebkitTextSecurity: "disc",
};
export function ChannelRuntimeConfigDialog({
provider,
open,
submitting,
onOpenChange,
onSubmit,
}: ChannelRuntimeConfigDialogProps) {
const { t } = useI18n();
const [values, setValues] = useState<ChannelRuntimeConfigValues>({});
const fields = useMemo(
() => provider?.credential_fields ?? [],
[provider?.credential_fields],
);
const credentialValues = useMemo<ChannelRuntimeConfigValues>(
() => provider?.credential_values ?? {},
[provider?.credential_values],
);
useEffect(() => {
if (!open || !provider) {
setValues({});
return;
}
setValues(
Object.fromEntries(
fields.map((field) => [field.name, credentialValues[field.name] ?? ""]),
) as ChannelRuntimeConfigValues,
);
}, [credentialValues, fields, open, provider]);
if (!provider) {
return null;
}
const isEditing = provider.configured;
const handleSubmit = (event: FormEvent<HTMLFormElement>) => {
event.preventDefault();
onSubmit(provider, values);
};
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent>
<form onSubmit={handleSubmit} className="space-y-4">
<DialogHeader>
<DialogTitle>
{isEditing
? t.channels.setupEditTitle(provider.display_name)
: t.channels.setupTitle(provider.display_name)}
</DialogTitle>
<DialogDescription>{t.channels.setupDescription}</DialogDescription>
</DialogHeader>
<div className="space-y-3">
{fields.map((field) => {
const inputId = `channel-${provider.provider}-${field.name}`;
const isSecretField = field.type === "password";
return (
<div key={field.name} className="space-y-1.5">
<label
htmlFor={inputId}
className="text-sm leading-none font-medium"
>
{field.label}
</label>
<Input
id={inputId}
type="text"
value={values[field.name] ?? ""}
required={field.required}
autoComplete="off"
autoCorrect="off"
autoCapitalize="none"
spellCheck={false}
data-1p-ignore={isSecretField ? "true" : undefined}
data-bwignore={isSecretField ? "true" : undefined}
data-form-type={isSecretField ? "other" : undefined}
data-lpignore={isSecretField ? "true" : undefined}
style={isSecretField ? SECRET_INPUT_STYLE : undefined}
onChange={(event) => {
setValues((current) => ({
...current,
[field.name]: event.target.value,
}));
}}
/>
</div>
);
})}
</div>
<DialogFooter>
<Button
type="button"
variant="outline"
disabled={submitting}
onClick={() => onOpenChange(false)}
>
{t.common.cancel}
</Button>
<Button type="submit" disabled={submitting}>
{submitting ? (
<LoaderCircleIcon className="animate-spin" />
) : null}
{isEditing ? t.channels.saveChanges : t.channels.saveAndConnect}
</Button>
</DialogFooter>
</form>
</DialogContent>
</Dialog>
);
}
@@ -0,0 +1,213 @@
"use client";
import { CheckIcon, LoaderCircleIcon } from "lucide-react";
import { useState } from "react";
import { toast } from "sonner";
import { Button } from "@/components/ui/button";
import {
SidebarGroup,
SidebarGroupLabel,
SidebarMenu,
SidebarMenuItem,
useSidebar,
} from "@/components/ui/sidebar";
import { Skeleton } from "@/components/ui/skeleton";
import {
useConfigureChannelProvider,
useChannelProviders,
useConnectChannelProvider,
} from "@/core/channels/hooks";
import {
closeConnectWindow,
openConnectUrl,
prepareConnectWindow,
} from "@/core/channels/open-connect-url";
import {
providerCanConnect,
providerCanEditRuntimeConfig,
providerNeedsRuntimeConfig,
} from "@/core/channels/provider-state";
import type { ChannelProvider } from "@/core/channels/types";
import { useI18n } from "@/core/i18n/hooks";
import { cn } from "@/lib/utils";
import { ChannelProviderIcon } from "./channel-provider-icon";
import { ChannelRuntimeConfigDialog } from "./channel-runtime-config-dialog";
function getProviderUnavailableReason(
provider: ChannelProvider,
t: ReturnType<typeof useI18n>["t"],
): string | undefined {
if (provider.unavailable_reason) {
return provider.unavailable_reason;
}
if (!provider.enabled) {
return t.channels.disabled;
}
if (!provider.configured) {
return t.channels.unconfigured;
}
return provider.unavailable_reason ?? undefined;
}
export function WorkspaceChannelsList() {
const { open: isSidebarOpen } = useSidebar();
const { t } = useI18n();
const { enabled, providers, isLoading, error } = useChannelProviders();
const connectMutation = useConnectChannelProvider();
const configureMutation = useConfigureChannelProvider();
const [setupProvider, setSetupProvider] = useState<ChannelProvider | null>(
null,
);
const visibleProviders = providers.filter((provider) => provider.enabled);
const startConnect = (
provider: ChannelProvider,
preparedWindow?: Window | null,
) => {
const connectWindow =
preparedWindow !== undefined
? preparedWindow
: provider.auth_mode === "deep_link"
? prepareConnectWindow()
: null;
void connectMutation
.mutateAsync(provider.provider)
.then((result) => {
if (result.url) {
openConnectUrl(result.url, connectWindow);
return;
}
closeConnectWindow(connectWindow);
toast.success(result.instruction);
})
.catch((error) => {
closeConnectWindow(connectWindow);
toast.error(
error instanceof Error ? error.message : t.channels.unavailable,
);
});
};
if (!isSidebarOpen) {
return null;
}
if (isLoading) {
return (
<SidebarGroup className="pt-0">
<SidebarGroupLabel>{t.sidebar.channels}</SidebarGroupLabel>
<div className="space-y-2 px-2 py-1">
<Skeleton className="h-8 w-full" />
<Skeleton className="h-8 w-full" />
<Skeleton className="h-8 w-full" />
</div>
</SidebarGroup>
);
}
if (error || !enabled || visibleProviders.length === 0) {
return null;
}
return (
<SidebarGroup className="pt-0">
<SidebarGroupLabel>{t.sidebar.channels}</SidebarGroupLabel>
<SidebarMenu>
{visibleProviders.map((provider) => {
const canEditRuntimeConfig = providerCanEditRuntimeConfig(provider);
const isConnected = provider.connection_status === "connected";
const isPending =
(connectMutation.isPending &&
connectMutation.variables === provider.provider) ||
(configureMutation.isPending &&
configureMutation.variables?.provider === provider.provider);
const canConnect = providerCanConnect(provider);
const unavailableReason = getProviderUnavailableReason(provider, t);
return (
<SidebarMenuItem key={provider.provider}>
<div className="hover:bg-sidebar-accent flex h-10 items-center gap-2 rounded-md px-2 transition-colors">
<ChannelProviderIcon
provider={provider.provider}
className="size-5 shrink-0"
/>
<span className="min-w-0 flex-1 truncate text-sm font-medium">
{provider.display_name}
</span>
<Button
type="button"
size="sm"
variant={isConnected ? "outline" : "secondary"}
className={cn(
"h-8 w-24 px-2 text-xs",
isConnected && "gap-1",
)}
disabled={isPending}
title={unavailableReason}
onClick={() => {
if (
providerNeedsRuntimeConfig(provider) ||
(isConnected && canEditRuntimeConfig)
) {
setSetupProvider(provider);
return;
}
if (!canConnect) {
toast.error(unavailableReason ?? t.channels.unavailable);
return;
}
startConnect(provider);
}}
>
{isPending ? (
<LoaderCircleIcon className="size-3.5 animate-spin" />
) : isConnected ? (
<CheckIcon className="size-3.5" />
) : null}
<span>
{isConnected ? t.channels.connected : t.channels.connect}
</span>
</Button>
</div>
</SidebarMenuItem>
);
})}
</SidebarMenu>
<ChannelRuntimeConfigDialog
provider={setupProvider}
open={setupProvider !== null}
submitting={configureMutation.isPending}
onOpenChange={(open) => {
if (!open) {
setSetupProvider(null);
}
}}
onSubmit={(provider, values) => {
const connectWindow =
provider.auth_mode === "deep_link" ? prepareConnectWindow() : null;
void configureMutation
.mutateAsync({ provider: provider.provider, values })
.then((updated) => {
setSetupProvider(null);
if (providerCanConnect(updated)) {
startConnect(updated, connectWindow);
return;
}
closeConnectWindow(connectWindow);
toast.success(t.channels.connected);
})
.catch((error) => {
closeConnectWindow(connectWindow);
toast.error(
error instanceof Error ? error.message : t.channels.unavailable,
);
});
}}
/>
</SidebarGroup>
);
}
@@ -55,10 +55,16 @@ import {
useRenameThread, useRenameThread,
} from "@/core/threads/hooks"; } from "@/core/threads/hooks";
import type { AgentThread, AgentThreadState } from "@/core/threads/types"; import type { AgentThread, AgentThreadState } from "@/core/threads/types";
import { pathOfThread, titleOfThread } from "@/core/threads/utils"; import {
channelSourceOfThread,
pathOfThread,
titleOfThread,
} from "@/core/threads/utils";
import { env } from "@/env"; import { env } from "@/env";
import { isIMEComposing } from "@/lib/ime"; import { isIMEComposing } from "@/lib/ime";
import { ThreadChannelIcon } from "./thread-channel-source";
export function RecentChatList() { export function RecentChatList() {
const { t } = useI18n(); const { t } = useI18n();
const router = useRouter(); const router = useRouter();
@@ -210,6 +216,7 @@ export function RecentChatList() {
<div className="flex w-full flex-col gap-1"> <div className="flex w-full flex-col gap-1">
{threads.map((thread) => { {threads.map((thread) => {
const isActive = pathOfThread(thread) === pathname; const isActive = pathOfThread(thread) === pathname;
const channelSource = channelSourceOfThread(thread);
return ( return (
<SidebarMenuItem <SidebarMenuItem
key={thread.thread_id} key={thread.thread_id}
@@ -218,10 +225,23 @@ export function RecentChatList() {
<SidebarMenuButton isActive={isActive} asChild> <SidebarMenuButton isActive={isActive} asChild>
<div> <div>
<Link <Link
className="text-muted-foreground block w-full whitespace-nowrap group-hover/side-menu-item:overflow-hidden" className="text-muted-foreground flex min-w-0 items-center gap-1.5 pr-7 whitespace-nowrap group-hover/side-menu-item:overflow-hidden"
href={pathOfThread(thread)} href={pathOfThread(thread)}
> >
{titleOfThread(thread)} <ThreadChannelIcon source={channelSource} />
<span className="min-w-0 truncate">
{titleOfThread(thread)}
</span>
{channelSource && (
<span
className="bg-muted text-muted-foreground ml-auto inline-flex h-5 max-w-14 shrink-0 items-center rounded-md px-1.5 text-[10px] font-medium"
title={`${channelSource.label} channel`}
>
<span className="truncate">
{channelSource.label}
</span>
</span>
)}
</Link> </Link>
{env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY !== "true" && ( {env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY !== "true" && (
<DropdownMenu> <DropdownMenu>
@@ -0,0 +1,377 @@
"use client";
import {
AlertCircleIcon,
CheckCircle2Icon,
LoaderCircleIcon,
PlugIcon,
UnplugIcon,
} from "lucide-react";
import { useState } from "react";
import { toast } from "sonner";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Item,
ItemActions,
ItemContent,
ItemDescription,
ItemMedia,
ItemTitle,
} from "@/components/ui/item";
import {
useConfigureChannelProvider,
useChannelConnections,
useChannelProviders,
useConnectChannelProvider,
useDisconnectChannelProvider,
} from "@/core/channels/hooks";
import {
closeConnectWindow,
openConnectUrl,
prepareConnectWindow,
} from "@/core/channels/open-connect-url";
import {
providerCanConnect,
providerCanEditRuntimeConfig,
providerNeedsRuntimeConfig,
} from "@/core/channels/provider-state";
import type { ChannelConnection, ChannelProvider } from "@/core/channels/types";
import { useI18n } from "@/core/i18n/hooks";
import { cn } from "@/lib/utils";
import { ChannelProviderIcon } from "../channels/channel-provider-icon";
import { ChannelRuntimeConfigDialog } from "../channels/channel-runtime-config-dialog";
import { SettingsSection } from "./settings-section";
function getProviderDescription(
provider: ChannelProvider,
descriptions: Record<string, string>,
): string {
return descriptions[provider.provider] ?? provider.display_name;
}
function getConnectionLabel(connection: ChannelConnection): string | null {
const account = connection.external_account_name;
const workspace = connection.workspace_name;
if (account && workspace) {
return `${account} · ${workspace}`;
}
return account ?? workspace ?? connection.external_account_id ?? null;
}
function getStatusLabel(
provider: ChannelProvider,
connection: ChannelConnection | undefined,
t: ReturnType<typeof useI18n>["t"],
): string {
if (!provider.enabled) {
return t.channels.disabled;
}
if (!provider.configured) {
return t.channels.unconfigured;
}
if (provider.unavailable_reason) {
return t.channels.unavailableShort;
}
const status = connection?.status ?? provider.connection_status;
if (status === "connected") {
return t.channels.connected;
}
if (status === "pending") {
return t.channels.pending;
}
if (status === "revoked") {
return t.channels.revoked;
}
return t.channels.notConnected;
}
function getProviderUnavailableReason(
provider: ChannelProvider,
t: ReturnType<typeof useI18n>["t"],
): string | undefined {
if (provider.unavailable_reason) {
return provider.unavailable_reason;
}
if (!provider.enabled) {
return t.channels.disabled;
}
if (!provider.configured) {
return t.channels.unconfigured;
}
return provider.unavailable_reason ?? undefined;
}
function ChannelProviderItem({
provider,
connection,
}: {
provider: ChannelProvider;
connection?: ChannelConnection;
}) {
const { t } = useI18n();
const connectMutation = useConnectChannelProvider();
const configureMutation = useConfigureChannelProvider();
const disconnectProviderMutation = useDisconnectChannelProvider();
const [setupOpen, setSetupOpen] = useState(false);
const isConnected =
connection?.status === "connected" ||
provider.connection_status === "connected";
const canEditRuntimeConfig = providerCanEditRuntimeConfig(provider);
const canConnect =
(provider.connectable ?? (provider.enabled && provider.configured)) &&
!isConnected;
const isConnecting =
(connectMutation.isPending &&
connectMutation.variables === provider.provider) ||
(configureMutation.isPending &&
configureMutation.variables?.provider === provider.provider);
const isDisconnecting =
disconnectProviderMutation.isPending &&
disconnectProviderMutation.variables === provider.provider;
const connectionLabel = connection ? getConnectionLabel(connection) : null;
const statusLabel = getStatusLabel(provider, connection, t);
const unavailableReason = getProviderUnavailableReason(provider, t);
const startConnect = (
connectProvider: ChannelProvider,
preparedWindow?: Window | null,
) => {
const connectWindow =
preparedWindow !== undefined
? preparedWindow
: connectProvider.auth_mode === "deep_link"
? prepareConnectWindow()
: null;
void connectMutation
.mutateAsync(connectProvider.provider)
.then((result) => {
if (result.url) {
openConnectUrl(result.url, connectWindow);
return;
}
closeConnectWindow(connectWindow);
toast.success(result.instruction);
})
.catch((error) => {
closeConnectWindow(connectWindow);
toast.error(
error instanceof Error ? error.message : t.channels.unavailable,
);
});
};
return (
<>
<Item variant="outline" className="w-full items-start">
<ItemMedia variant="icon" className="bg-background">
<ChannelProviderIcon
provider={provider.provider}
className="size-5"
/>
</ItemMedia>
<ItemContent className="min-w-0">
<ItemTitle className="w-full">
<span className="truncate">{provider.display_name}</span>
<Badge
variant={isConnected ? "default" : "outline"}
className={cn(!isConnected && "text-muted-foreground")}
>
{isConnected ? <CheckCircle2Icon /> : <AlertCircleIcon />}
{statusLabel}
</Badge>
</ItemTitle>
<ItemDescription className="line-clamp-none">
{getProviderDescription(provider, t.channels.descriptions)}
{connectionLabel
? ` ${t.channels.connectedAs(connectionLabel)}`
: ""}
{!isConnected && provider.unavailable_reason
? ` ${provider.unavailable_reason}`
: ""}
</ItemDescription>
</ItemContent>
<ItemActions className="ml-auto">
{isConnected ? (
<>
{canEditRuntimeConfig ? (
<Button
type="button"
variant="outline"
size="sm"
disabled={isConnecting || isDisconnecting}
onClick={() => setSetupOpen(true)}
>
{isConnecting ? (
<LoaderCircleIcon className="animate-spin" />
) : (
<PlugIcon />
)}
{t.channels.modify}
</Button>
) : null}
<Button
type="button"
variant="outline"
size="sm"
disabled={isDisconnecting}
onClick={() => {
void disconnectProviderMutation
.mutateAsync(provider.provider)
.then(() => {
toast.success(t.channels.revoked);
})
.catch((error) => {
toast.error(
error instanceof Error
? error.message
: t.channels.unavailable,
);
});
}}
>
{isDisconnecting ? (
<LoaderCircleIcon className="animate-spin" />
) : (
<UnplugIcon />
)}
{t.channels.disconnect}
</Button>
</>
) : (
<>
{provider.configured && canEditRuntimeConfig ? (
<Button
type="button"
variant="outline"
size="sm"
disabled={isConnecting || isDisconnecting}
onClick={() => setSetupOpen(true)}
>
{t.channels.modify}
</Button>
) : null}
<Button
type="button"
size="sm"
disabled={isConnecting}
title={unavailableReason}
onClick={() => {
if (providerNeedsRuntimeConfig(provider)) {
setSetupOpen(true);
return;
}
if (!canConnect) {
toast.error(unavailableReason ?? t.channels.unavailable);
return;
}
startConnect(provider);
}}
>
{isConnecting ? (
<LoaderCircleIcon className="animate-spin" />
) : (
<PlugIcon />
)}
{connection?.status === "revoked"
? t.channels.reconnect
: t.channels.connect}
</Button>
</>
)}
</ItemActions>
</Item>
<ChannelRuntimeConfigDialog
provider={provider}
open={setupOpen}
submitting={configureMutation.isPending}
onOpenChange={setSetupOpen}
onSubmit={(submitProvider, values) => {
const connectWindow =
submitProvider.auth_mode === "deep_link"
? prepareConnectWindow()
: null;
void configureMutation
.mutateAsync({ provider: submitProvider.provider, values })
.then((updated) => {
setSetupOpen(false);
if (providerCanConnect(updated)) {
startConnect(updated, connectWindow);
return;
}
closeConnectWindow(connectWindow);
toast.success(t.channels.connected);
})
.catch((error) => {
closeConnectWindow(connectWindow);
toast.error(
error instanceof Error ? error.message : t.channels.unavailable,
);
});
}}
/>
</>
);
}
export function ChannelsSettingsPage() {
const { t } = useI18n();
const {
enabled,
providers,
isLoading: providersLoading,
error: providersError,
} = useChannelProviders();
const {
connections,
isLoading: connectionsLoading,
error: connectionsError,
} = useChannelConnections();
const isLoading = providersLoading || connectionsLoading;
const error = providersError ?? connectionsError;
const visibleProviders = providers.filter((provider) => provider.enabled);
const connectionByProvider = new Map<string, ChannelConnection>();
for (const connection of connections) {
const existing = connectionByProvider.get(connection.provider);
if (!existing || connection.status === "connected") {
connectionByProvider.set(connection.provider, connection);
}
}
return (
<SettingsSection
title={t.settings.channels.title}
description={t.settings.channels.description}
>
{isLoading ? (
<div className="text-muted-foreground text-sm">{t.common.loading}</div>
) : error ? (
<div className="text-destructive text-sm">{t.channels.unavailable}</div>
) : !enabled ? (
<div className="text-muted-foreground text-sm">
{t.settings.channels.disabled}
</div>
) : visibleProviders.length === 0 ? (
<div className="text-muted-foreground text-sm">
{t.settings.channels.disabled}
</div>
) : (
<div className="flex w-full flex-col gap-4">
{visibleProviders.map((provider) => (
<ChannelProviderItem
key={provider.provider}
provider={provider}
connection={connectionByProvider.get(provider.provider)}
/>
))}
</div>
)}
</SettingsSection>
);
}
@@ -2,6 +2,7 @@
import { import {
BellIcon, BellIcon,
CableIcon,
InfoIcon, InfoIcon,
BrainIcon, BrainIcon,
PaletteIcon, PaletteIcon,
@@ -21,6 +22,7 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { AboutSettingsPage } from "@/components/workspace/settings/about-settings-page"; import { AboutSettingsPage } from "@/components/workspace/settings/about-settings-page";
import { AccountSettingsPage } from "@/components/workspace/settings/account-settings-page"; import { AccountSettingsPage } from "@/components/workspace/settings/account-settings-page";
import { AppearanceSettingsPage } from "@/components/workspace/settings/appearance-settings-page"; import { AppearanceSettingsPage } from "@/components/workspace/settings/appearance-settings-page";
import { ChannelsSettingsPage } from "@/components/workspace/settings/channels-settings-page";
import { MemorySettingsPage } from "@/components/workspace/settings/memory-settings-page"; import { MemorySettingsPage } from "@/components/workspace/settings/memory-settings-page";
import { NotificationSettingsPage } from "@/components/workspace/settings/notification-settings-page"; import { NotificationSettingsPage } from "@/components/workspace/settings/notification-settings-page";
import { SkillSettingsPage } from "@/components/workspace/settings/skill-settings-page"; import { SkillSettingsPage } from "@/components/workspace/settings/skill-settings-page";
@@ -31,6 +33,7 @@ import { cn } from "@/lib/utils";
type SettingsSection = type SettingsSection =
| "account" | "account"
| "appearance" | "appearance"
| "channels"
| "memory" | "memory"
| "tools" | "tools"
| "skills" | "skills"
@@ -72,6 +75,11 @@ export function SettingsDialog(props: SettingsDialogProps) {
label: t.settings.sections.notification, label: t.settings.sections.notification,
icon: BellIcon, icon: BellIcon,
}, },
{
id: "channels",
label: t.settings.sections.channels,
icon: CableIcon,
},
{ {
id: "memory", id: "memory",
label: t.settings.sections.memory, label: t.settings.sections.memory,
@@ -84,6 +92,7 @@ export function SettingsDialog(props: SettingsDialogProps) {
[ [
t.settings.sections.account, t.settings.sections.account,
t.settings.sections.appearance, t.settings.sections.appearance,
t.settings.sections.channels,
t.settings.sections.memory, t.settings.sections.memory,
t.settings.sections.tools, t.settings.sections.tools,
t.settings.sections.skills, t.settings.sections.skills,
@@ -143,6 +152,7 @@ export function SettingsDialog(props: SettingsDialogProps) {
/> />
)} )}
{activeSection === "notification" && <NotificationSettingsPage />} {activeSection === "notification" && <NotificationSettingsPage />}
{activeSection === "channels" && <ChannelsSettingsPage />}
{activeSection === "about" && <AboutSettingsPage />} {activeSection === "about" && <AboutSettingsPage />}
</div> </div>
</ScrollArea> </ScrollArea>
@@ -0,0 +1,56 @@
"use client";
import { ChannelProviderIcon } from "@/components/workspace/channels/channel-provider-icon";
import type { ChannelThreadSource } from "@/core/threads/utils";
import { cn } from "@/lib/utils";
type ThreadChannelIconProps = {
source: ChannelThreadSource | null;
className?: string;
};
export function ThreadChannelIcon({
source,
className,
}: ThreadChannelIconProps) {
if (!source) {
return null;
}
return (
<span
aria-label={`${source.label} channel`}
title={`${source.label} channel`}
className={cn("inline-flex shrink-0 items-center", className)}
>
<ChannelProviderIcon provider={source.provider} className="size-4" />
</span>
);
}
type ThreadChannelBadgeProps = {
source: ChannelThreadSource | null;
className?: string;
};
export function ThreadChannelBadge({
source,
className,
}: ThreadChannelBadgeProps) {
if (!source) {
return null;
}
return (
<span
className={cn(
"bg-muted text-muted-foreground inline-flex h-6 max-w-32 items-center gap-1 rounded-md px-2 text-xs font-medium",
className,
)}
title={`${source.label} channel`}
>
<ChannelProviderIcon provider={source.provider} className="size-3.5" />
<span className="truncate">{source.label}</span>
</span>
);
}
@@ -9,6 +9,7 @@ import {
useSidebar, useSidebar,
} from "@/components/ui/sidebar"; } from "@/components/ui/sidebar";
import { WorkspaceChannelsList } from "./channels/workspace-channels-list";
import { RecentChatList } from "./recent-chat-list"; import { RecentChatList } from "./recent-chat-list";
import { WorkspaceHeader } from "./workspace-header"; import { WorkspaceHeader } from "./workspace-header";
import { WorkspaceNavChatList } from "./workspace-nav-chat-list"; import { WorkspaceNavChatList } from "./workspace-nav-chat-list";
@@ -26,6 +27,7 @@ export function WorkspaceSidebar({
</SidebarHeader> </SidebarHeader>
<SidebarContent> <SidebarContent>
<WorkspaceNavChatList /> <WorkspaceNavChatList />
<WorkspaceChannelsList />
{isSidebarOpen && <RecentChatList />} {isSidebarOpen && <RecentChatList />}
</SidebarContent> </SidebarContent>
<SidebarFooter> <SidebarFooter>
+2 -2
View File
@@ -1,8 +1,8 @@
import type { User } from "./types"; import type { User } from "./types";
export const AUTH_DISABLED_USER: User = { export const AUTH_DISABLED_USER: User = {
id: "e2e-user", id: "default",
email: "e2e@test.local", email: "default@test.local",
system_role: "admin", system_role: "admin",
needs_setup: false, needs_setup: false,
}; };
+117
View File
@@ -0,0 +1,117 @@
import { fetch } from "@/core/api/fetcher";
import { getBackendBaseURL } from "@/core/config";
import type {
ChannelConnectResponse,
ChannelConnection,
ChannelConnectionsResponse,
ChannelProviderId,
ChannelProvider,
ChannelProvidersResponse,
ChannelRuntimeConfigValues,
} from "./types";
function channelsUrl(path: string): string {
return `${getBackendBaseURL()}/api/channels${path}`;
}
async function throwChannelApiError(
response: Response,
fallback: string,
): Promise<never> {
const body = (await response.json().catch(() => ({}))) as {
detail?: unknown;
};
throw new Error(typeof body.detail === "string" ? body.detail : fallback);
}
export async function listChannelProviders(): Promise<ChannelProvidersResponse> {
const response = await fetch(channelsUrl("/providers"));
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to load channel providers: ${response.statusText}`,
);
}
return response.json() as Promise<ChannelProvidersResponse>;
}
export async function listChannelConnections(): Promise<ChannelConnection[]> {
const response = await fetch(channelsUrl("/connections"));
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to load channel connections: ${response.statusText}`,
);
}
const data = (await response.json()) as ChannelConnectionsResponse;
return data.connections;
}
export async function connectChannelProvider(
provider: ChannelProviderId,
): Promise<ChannelConnectResponse> {
const response = await fetch(
channelsUrl(`/${encodeURIComponent(provider)}/connect`),
{ method: "POST" },
);
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to connect ${provider}: ${response.statusText}`,
);
}
return response.json() as Promise<ChannelConnectResponse>;
}
export async function configureChannelProvider(
provider: ChannelProviderId,
values: ChannelRuntimeConfigValues,
): Promise<ChannelProvider> {
const response = await fetch(
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ values }),
},
);
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to configure ${provider}: ${response.statusText}`,
);
}
return response.json() as Promise<ChannelProvider>;
}
export async function disconnectChannelConnection(
connectionId: string,
): Promise<void> {
const response = await fetch(
channelsUrl(`/connections/${encodeURIComponent(connectionId)}`),
{ method: "DELETE" },
);
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to disconnect channel: ${response.statusText}`,
);
}
}
export async function disconnectChannelProvider(
provider: ChannelProviderId,
): Promise<ChannelProvider> {
const response = await fetch(
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
{ method: "DELETE" },
);
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to disconnect ${provider}: ${response.statusText}`,
);
}
return response.json() as Promise<ChannelProvider>;
}
+96
View File
@@ -0,0 +1,96 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import {
configureChannelProvider,
connectChannelProvider,
disconnectChannelConnection,
disconnectChannelProvider,
listChannelConnections,
listChannelProviders,
} from "./api";
import type { ChannelProviderId, ChannelRuntimeConfigValues } from "./types";
export const channelProviderQueryKey = ["channelProviders"] as const;
export const channelConnectionsQueryKey = ["channelConnections"] as const;
export function useChannelProviders() {
const { data, isLoading, error } = useQuery({
queryKey: channelProviderQueryKey,
queryFn: () => listChannelProviders(),
});
return {
enabled: data?.enabled ?? false,
providers: data?.providers ?? [],
isLoading,
error,
};
}
export function useChannelConnections() {
const { data, isLoading, error } = useQuery({
queryKey: channelConnectionsQueryKey,
queryFn: () => listChannelConnections(),
});
return { connections: data ?? [], isLoading, error };
}
export function useConnectChannelProvider() {
const queryClient = useQueryClient();
return useMutation({
mutationFn: (provider: ChannelProviderId) =>
connectChannelProvider(provider),
onSuccess: () => {
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
void queryClient.invalidateQueries({
queryKey: channelConnectionsQueryKey,
});
},
});
}
export function useConfigureChannelProvider() {
const queryClient = useQueryClient();
return useMutation({
mutationFn: ({
provider,
values,
}: {
provider: ChannelProviderId;
values: ChannelRuntimeConfigValues;
}) => configureChannelProvider(provider, values),
onSuccess: () => {
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
void queryClient.invalidateQueries({
queryKey: channelConnectionsQueryKey,
});
},
});
}
export function useDisconnectChannelConnection() {
const queryClient = useQueryClient();
return useMutation({
mutationFn: (connectionId: string) =>
disconnectChannelConnection(connectionId),
onSuccess: () => {
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
void queryClient.invalidateQueries({
queryKey: channelConnectionsQueryKey,
});
},
});
}
export function useDisconnectChannelProvider() {
const queryClient = useQueryClient();
return useMutation({
mutationFn: (provider: ChannelProviderId) =>
disconnectChannelProvider(provider),
onSuccess: () => {
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
void queryClient.invalidateQueries({
queryKey: channelConnectionsQueryKey,
});
},
});
}
@@ -0,0 +1,27 @@
export type ChannelConnectWindow = Window | null;
export function prepareConnectWindow(): ChannelConnectWindow {
const opened = window.open("about:blank", "_blank");
if (opened) {
opened.opener = null;
}
return opened;
}
export function openConnectUrl(
url: string,
connectWindow: ChannelConnectWindow = prepareConnectWindow(),
) {
if (connectWindow && !connectWindow.closed) {
connectWindow.location.replace(url);
return;
}
window.location.assign(url);
}
export function closeConnectWindow(connectWindow: ChannelConnectWindow) {
if (connectWindow && !connectWindow.closed) {
connectWindow.close();
}
}
@@ -0,0 +1,22 @@
import type { ChannelProvider } from "./types";
export function providerCanConnect(provider: ChannelProvider): boolean {
return (
(provider.connectable ?? (provider.enabled && provider.configured)) &&
provider.connection_status !== "connected"
);
}
export function providerNeedsRuntimeConfig(provider: ChannelProvider): boolean {
return (
provider.enabled &&
!provider.configured &&
(provider.credential_fields?.length ?? 0) > 0
);
}
export function providerCanEditRuntimeConfig(
provider: ChannelProvider,
): boolean {
return provider.enabled && (provider.credential_fields?.length ?? 0) > 0;
}
+53
View File
@@ -0,0 +1,53 @@
export type ChannelProviderId = "telegram" | "slack" | "discord" | string;
export interface ChannelCredentialField {
name: string;
label: string;
type: string;
required: boolean;
}
export type ChannelRuntimeConfigValues = Record<string, string>;
export interface ChannelProvider {
provider: ChannelProviderId;
display_name: string;
enabled: boolean;
configured: boolean;
connectable?: boolean;
unavailable_reason?: string | null;
auth_mode: string;
connection_status: string;
credential_fields: ChannelCredentialField[];
credential_values?: ChannelRuntimeConfigValues;
}
export interface ChannelProvidersResponse {
enabled: boolean;
providers: ChannelProvider[];
}
export interface ChannelConnection {
id: string;
provider: ChannelProviderId;
status: string;
external_account_id?: string | null;
external_account_name?: string | null;
workspace_id?: string | null;
workspace_name?: string | null;
scopes: string[];
metadata: Record<string, unknown>;
}
export interface ChannelConnectionsResponse {
connections: ChannelConnection[];
}
export interface ChannelConnectResponse {
provider: ChannelProviderId;
mode: string;
url?: string | null;
code: string;
instruction: string;
expires_in: number;
}
+42
View File
@@ -170,6 +170,7 @@ export const enUS: Translations = {
sidebar: { sidebar: {
newChat: "New chat", newChat: "New chat",
chats: "Chats", chats: "Chats",
channels: "Channels",
recentChats: "Recent chats", recentChats: "Recent chats",
demoChats: "Demo chats", demoChats: "Demo chats",
agents: "Agents", agents: "Agents",
@@ -259,6 +260,39 @@ export const enUS: Translations = {
loadOlderChats: "Load older chats", loadOlderChats: "Load older chats",
}, },
// Channels
channels: {
title: "Channels",
connect: "Connect",
modify: "Modify",
reconnect: "Reconnect",
disconnect: "Disconnect",
connected: "Connected",
notConnected: "Not connected",
pending: "Pending",
revoked: "Disconnected",
disabled: "Disabled",
unconfigured: "Not configured",
unavailable: "Channel connections are unavailable right now.",
unavailableShort: "Unavailable",
setupTitle: (name: string) => `Connect ${name}`,
setupEditTitle: (name: string) => `Modify ${name}`,
setupDescription:
"Enter the values needed by this server process. They are not written to config.yaml.",
saveAndConnect: "Save and connect",
saveChanges: "Save changes",
descriptions: {
telegram: "Telegram direct messages through your DeerFlow bot.",
slack: "Slack workspace messages and mentions.",
discord: "Discord server messages through your DeerFlow bot.",
feishu: "Feishu and Lark messages through your DeerFlow app.",
dingtalk: "DingTalk Stream Push messages through your DeerFlow bot.",
wechat: "WeChat iLink messages through your DeerFlow bot.",
wecom: "WeCom messages through your DeerFlow AI bot.",
},
connectedAs: (name: string) => `Connected as ${name}.`,
},
// Page titles (document title) // Page titles (document title)
pages: { pages: {
appName: "DeerFlow", appName: "DeerFlow",
@@ -359,6 +393,7 @@ export const enUS: Translations = {
sections: { sections: {
account: "Account", account: "Account",
appearance: "Appearance", appearance: "Appearance",
channels: "Channels",
memory: "Memory", memory: "Memory",
tools: "Tools", tools: "Tools",
skills: "Skills", skills: "Skills",
@@ -461,6 +496,13 @@ export const enUS: Translations = {
title: "Tools", title: "Tools",
description: "Manage the configuration and enabled status of MCP tools.", description: "Manage the configuration and enabled status of MCP tools.",
}, },
channels: {
title: "Channels",
description:
"Connect IM accounts that can send messages to DeerFlow from outside the browser.",
disabled:
"Channel connections are not enabled on this server. Ask an administrator to enable channel_connections.",
},
skills: { skills: {
title: "Agent Skills", title: "Agent Skills",
description: description:
+31
View File
@@ -117,6 +117,7 @@ export interface Translations {
chats: string; chats: string;
demoChats: string; demoChats: string;
agents: string; agents: string;
channels: string;
}; };
// Agents // Agents
@@ -190,6 +191,30 @@ export interface Translations {
loadOlderChats: string; loadOlderChats: string;
}; };
// Channels
channels: {
title: string;
connect: string;
modify: string;
reconnect: string;
disconnect: string;
connected: string;
notConnected: string;
pending: string;
revoked: string;
disabled: string;
unconfigured: string;
unavailable: string;
unavailableShort: string;
setupTitle: (name: string) => string;
setupEditTitle: (name: string) => string;
setupDescription: string;
saveAndConnect: string;
saveChanges: string;
descriptions: Record<string, string>;
connectedAs: (name: string) => string;
};
// Page titles (document title) // Page titles (document title)
pages: { pages: {
appName: string; appName: string;
@@ -286,6 +311,7 @@ export interface Translations {
sections: { sections: {
account: string; account: string;
appearance: string; appearance: string;
channels: string;
memory: string; memory: string;
tools: string; tools: string;
skills: string; skills: string;
@@ -381,6 +407,11 @@ export interface Translations {
title: string; title: string;
description: string; description: string;
}; };
channels: {
title: string;
description: string;
disabled: string;
};
skills: { skills: {
title: string; title: string;
description: string; description: string;
+41
View File
@@ -164,6 +164,7 @@ export const zhCN: Translations = {
sidebar: { sidebar: {
newChat: "新对话", newChat: "新对话",
chats: "对话", chats: "对话",
channels: "渠道",
recentChats: "最近的对话", recentChats: "最近的对话",
demoChats: "演示对话", demoChats: "演示对话",
agents: "智能体", agents: "智能体",
@@ -247,6 +248,39 @@ export const zhCN: Translations = {
loadOlderChats: "加载更早的对话", loadOlderChats: "加载更早的对话",
}, },
// Channels
channels: {
title: "渠道",
connect: "连接",
modify: "修改",
reconnect: "重新连接",
disconnect: "断开连接",
connected: "已连接",
notConnected: "未连接",
pending: "待完成",
revoked: "已断开",
disabled: "已停用",
unconfigured: "未配置",
unavailable: "当前无法使用渠道连接。",
unavailableShort: "不可用",
setupTitle: (name: string) => `连接 ${name}`,
setupEditTitle: (name: string) => `修改 ${name}`,
setupDescription:
"填写当前服务进程需要的配置值。这些内容不会写入 config.yaml。",
saveAndConnect: "保存并连接",
saveChanges: "保存修改",
descriptions: {
telegram: "通过 DeerFlow Bot 接收 Telegram 私聊消息。",
slack: "接收 Slack 工作区消息和提及。",
discord: "通过 DeerFlow Bot 接收 Discord 服务器消息。",
feishu: "通过 DeerFlow 应用接收飞书和 Lark 消息。",
dingtalk: "通过 DeerFlow Bot 接收钉钉 Stream Push 消息。",
wechat: "通过 DeerFlow Bot 接收微信 iLink 消息。",
wecom: "通过 DeerFlow AI Bot 接收企业微信消息。",
},
connectedAs: (name: string) => `已连接为 ${name}`,
},
// Page titles (document title) // Page titles (document title)
pages: { pages: {
appName: "DeerFlow", appName: "DeerFlow",
@@ -343,6 +377,7 @@ export const zhCN: Translations = {
sections: { sections: {
account: "账号", account: "账号",
appearance: "外观", appearance: "外观",
channels: "渠道",
memory: "记忆", memory: "记忆",
tools: "工具", tools: "工具",
skills: "技能", skills: "技能",
@@ -442,6 +477,12 @@ export const zhCN: Translations = {
title: "工具", title: "工具",
description: "管理 MCP 工具的配置和启用状态。", description: "管理 MCP 工具的配置和启用状态。",
}, },
channels: {
title: "渠道",
description: "连接可在浏览器外向 DeerFlow 发送消息的即时通讯账号。",
disabled:
"当前服务器未启用渠道连接。请联系管理员开启 channel_connections。",
},
skills: { skills: {
title: "技能", title: "技能",
description: "管理 Agent Skill 配置和启用状态。", description: "管理 Agent Skill 配置和启用状态。",
+7 -60
View File
@@ -26,6 +26,11 @@ import type { UploadedFileInfo } from "../uploads";
import { promptInputFilePartToFile, uploadFiles } from "../uploads"; import { promptInputFilePartToFile, uploadFiles } from "../uploads";
import { fetchThreadTokenUsage } from "./api"; import { fetchThreadTokenUsage } from "./api";
import {
buildThreadsSearchQueryOptions,
DEFAULT_THREAD_SEARCH_PARAMS,
type ThreadSearchParams,
} from "./thread-search-query";
import { threadTokenUsageQueryKey } from "./token-usage"; import { threadTokenUsageQueryKey } from "./token-usage";
import type { import type {
AgentThread, AgentThread,
@@ -1201,69 +1206,11 @@ export function useThreadHistory(
} }
export function useThreads( export function useThreads(
params: Parameters<ThreadsClient["search"]>[0] = { params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
limit: 50,
sortBy: "updated_at",
sortOrder: "desc",
select: ["thread_id", "updated_at", "values", "metadata"],
},
) { ) {
const apiClient = getAPIClient(); const apiClient = getAPIClient();
return useQuery<AgentThread[]>({ return useQuery<AgentThread[]>({
queryKey: ["threads", "search", params], ...buildThreadsSearchQueryOptions(apiClient, params),
queryFn: async () => {
const maxResults = params.limit;
const initialOffset = params.offset ?? 0;
const DEFAULT_PAGE_SIZE = 50;
// Preserve prior semantics: if a non-positive limit is explicitly provided,
// delegate to a single search call with the original parameters.
if (maxResults !== undefined && maxResults <= 0) {
const response =
await apiClient.threads.search<AgentThreadState>(params);
return response as AgentThread[];
}
const pageSize =
typeof maxResults === "number" && maxResults > 0
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
: DEFAULT_PAGE_SIZE;
const threads: AgentThread[] = [];
let offset = initialOffset;
while (true) {
if (typeof maxResults === "number" && threads.length >= maxResults) {
break;
}
const currentLimit =
typeof maxResults === "number"
? Math.min(pageSize, maxResults - threads.length)
: pageSize;
if (typeof maxResults === "number" && currentLimit <= 0) {
break;
}
const response = (await apiClient.threads.search<AgentThreadState>({
...params,
limit: currentLimit,
offset,
})) as AgentThread[];
threads.push(...response);
if (response.length < currentLimit) {
break;
}
offset += response.length;
}
return threads;
},
refetchOnWindowFocus: false,
}); });
} }
@@ -0,0 +1,86 @@
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import type { AgentThread, AgentThreadState } from "./types";
type ThreadsSearchClient = {
threads: {
search: ThreadsClient["search"];
};
};
export type ThreadSearchParams = NonNullable<
Parameters<ThreadsClient["search"]>[0]
>;
export const DEFAULT_THREAD_SEARCH_PARAMS: ThreadSearchParams = {
limit: 50,
sortBy: "updated_at",
sortOrder: "desc",
select: ["thread_id", "updated_at", "values", "metadata"],
};
export const THREAD_SEARCH_REFETCH_INTERVAL_MS = 5000;
export function buildThreadsSearchQueryOptions(
apiClient: ThreadsSearchClient,
params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
) {
return {
queryKey: ["threads", "search", params],
queryFn: async () => {
const maxResults = params.limit;
const initialOffset = params.offset ?? 0;
const DEFAULT_PAGE_SIZE = 50;
// Preserve prior semantics: if a non-positive limit is explicitly provided,
// delegate to a single search call with the original parameters.
if (maxResults !== undefined && maxResults <= 0) {
const response =
await apiClient.threads.search<AgentThreadState>(params);
return response as AgentThread[];
}
const pageSize =
typeof maxResults === "number" && maxResults > 0
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
: DEFAULT_PAGE_SIZE;
const threads: AgentThread[] = [];
let offset = initialOffset;
while (true) {
if (typeof maxResults === "number" && threads.length >= maxResults) {
break;
}
const currentLimit =
typeof maxResults === "number"
? Math.min(pageSize, maxResults - threads.length)
: pageSize;
if (typeof maxResults === "number" && currentLimit <= 0) {
break;
}
const response = (await apiClient.threads.search<AgentThreadState>({
...params,
limit: currentLimit,
offset,
})) as AgentThread[];
threads.push(...response);
if (response.length < currentLimit) {
break;
}
offset += response.length;
}
return threads;
},
refetchInterval: THREAD_SEARCH_REFETCH_INTERVAL_MS,
refetchIntervalInBackground: false,
refetchOnWindowFocus: false,
};
}
+45
View File
@@ -2,6 +2,12 @@ import type { Message } from "@langchain/langgraph-sdk";
import type { AgentThread, AgentThreadContext } from "./types"; import type { AgentThread, AgentThreadContext } from "./types";
export type ChannelThreadSource = {
type: "im_channel";
provider: string;
label: string;
};
type ThreadRouteTarget = type ThreadRouteTarget =
| string | string
| { | {
@@ -49,3 +55,42 @@ export function textOfMessage(message: Message) {
export function titleOfThread(thread: AgentThread) { export function titleOfThread(thread: AgentThread) {
return thread.values?.title ?? "Untitled"; return thread.values?.title ?? "Untitled";
} }
const CHANNEL_PROVIDER_LABELS: Record<string, string> = {
dingtalk: "DingTalk",
discord: "Discord",
feishu: "Feishu",
slack: "Slack",
telegram: "Telegram",
wechat: "WeChat",
wecom: "WeCom",
};
function labelOfChannelProvider(provider: string) {
return CHANNEL_PROVIDER_LABELS[provider] ?? provider;
}
export function channelSourceOfThread(
thread: Pick<AgentThread, "metadata">,
): ChannelThreadSource | null {
const source = thread.metadata?.channel_source;
if (!source || typeof source !== "object" || Array.isArray(source)) {
return null;
}
if (Reflect.get(source, "type") !== "im_channel") {
return null;
}
const provider = Reflect.get(source, "provider");
if (typeof provider !== "string" || provider.trim().length === 0) {
return null;
}
const normalizedProvider = provider.trim().toLowerCase();
return {
type: "im_channel",
provider: normalizedProvider,
label: labelOfChannelProvider(normalizedProvider),
};
}
+8 -2
View File
@@ -19,16 +19,22 @@ interface Shortcut {
export function useGlobalShortcuts(shortcuts: Shortcut[]) { export function useGlobalShortcuts(shortcuts: Shortcut[]) {
useEffect(() => { useEffect(() => {
function handleKeyDown(event: KeyboardEvent) { function handleKeyDown(event: KeyboardEvent) {
if (typeof event.key !== "string" || event.key.length === 0) {
return;
}
const meta = event.metaKey || event.ctrlKey; const meta = event.metaKey || event.ctrlKey;
const eventKey = event.key.toLowerCase();
for (const shortcut of shortcuts) { for (const shortcut of shortcuts) {
const shortcutKey = shortcut.key.toLowerCase();
if ( if (
event.key.toLowerCase() === shortcut.key.toLowerCase() && eventKey === shortcutKey &&
meta === shortcut.meta && meta === shortcut.meta &&
(shortcut.shift ?? false) === event.shiftKey (shortcut.shift ?? false) === event.shiftKey
) { ) {
// Allow Cmd+K even in inputs (standard command palette behavior) // Allow Cmd+K even in inputs (standard command palette behavior)
if (shortcut.key !== "k") { if (shortcutKey !== "k") {
const target = event.target as HTMLElement; const target = event.target as HTMLElement;
const tag = target.tagName; const tag = target.tagName;
if ( if (
+452
View File
@@ -0,0 +1,452 @@
import { expect, test, type Page } from "@playwright/test";
import { mockLangGraphAPI } from "./utils/mock-api";
const channelProviders = [
["telegram", "Telegram", "deep_link"],
["slack", "Slack", "binding_code"],
["discord", "Discord", "binding_code"],
["feishu", "Feishu", "binding_code"],
["dingtalk", "DingTalk", "binding_code"],
["wechat", "WeChat", "binding_code"],
["wecom", "WeCom", "binding_code"],
] as const;
type MockChannelProvider = {
provider: string;
display_name: string;
enabled: boolean;
configured: boolean;
connectable: boolean;
auth_mode: string;
connection_status: string;
unavailable_reason?: string | null;
credential_fields?: Array<{
name: string;
label: string;
type: string;
required: boolean;
}>;
credential_values?: Record<string, string>;
};
function defaultProviders(): MockChannelProvider[] {
return channelProviders.map(([provider, displayName, authMode]) => ({
provider,
display_name: displayName,
enabled: true,
configured: true,
connectable: true,
auth_mode: authMode,
connection_status: "connected",
credential_fields: [
{
name: "token",
label: "Token",
type: "password",
required: true,
},
],
}));
}
function mockChannelsAPI(
page: Page,
providers: MockChannelProvider[] = defaultProviders(),
onSlackConnect?: () => void,
) {
void page.route("**/api/channels/providers", (route) => {
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
enabled: true,
providers,
}),
});
});
void page.route("**/api/channels/connections", (route) => {
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ connections: [] }),
});
});
void page.route("**/api/channels/slack/connect", (route) => {
onSlackConnect?.();
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
provider: "slack",
mode: "binding_code",
url: null,
code: "abc123",
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
expires_in: 600,
}),
});
});
}
test.describe("IM channels", () => {
test("sidebar and settings expose channel connections", async ({ page }) => {
mockLangGraphAPI(page);
mockChannelsAPI(page);
await page.goto("/workspace/chats/new");
const sidebar = page.locator("[data-sidebar='sidebar']");
await expect(sidebar.getByText("Channels")).toBeVisible({
timeout: 15_000,
});
await expect(sidebar.getByText("Telegram")).toBeVisible();
await expect(sidebar.getByText("Slack")).toBeVisible();
await expect(sidebar.getByText("Discord")).toBeVisible();
await expect(sidebar.getByText("Feishu")).toBeVisible();
await expect(sidebar.getByText("DingTalk")).toBeVisible();
await expect(sidebar.getByText("WeChat")).toBeVisible();
await expect(sidebar.getByText("WeCom")).toBeVisible();
await expect(
sidebar.getByRole("button", { name: "Connected" }),
).toHaveCount(7);
await sidebar.getByRole("button", { name: /Settings and more/ }).click();
await page.getByRole("menuitem", { name: "Settings" }).click();
await page.getByRole("button", { name: "Channels" }).click();
await expect(page.getByText("Telegram direct messages")).toBeVisible();
await expect(page.getByText("Slack workspace messages")).toBeVisible();
await expect(page.getByText("Discord server messages")).toBeVisible();
await expect(page.getByText("Feishu and Lark messages")).toBeVisible();
await expect(page.getByText("DingTalk Stream Push messages")).toBeVisible();
await expect(page.getByText("WeChat iLink messages")).toBeVisible();
await expect(page.getByText("WeCom messages")).toBeVisible();
const dialog = page.getByRole("dialog", { name: "Settings" });
await expect(dialog.getByRole("button", { name: "Modify" })).toHaveCount(7);
});
test("only enabled providers are shown and runtime setup stays editable", async ({
page,
}) => {
mockLangGraphAPI(page);
let slackConfigured = false;
let submittedValues: Record<string, string> | undefined;
void page.route("**/api/channels/providers", (route) => {
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
enabled: true,
providers: [
{
provider: "slack",
display_name: "Slack",
enabled: true,
configured: slackConfigured,
connectable: slackConfigured,
auth_mode: "binding_code",
connection_status: slackConfigured
? "connected"
: "not_connected",
credential_fields: [
{
name: "bot_token",
label: "Bot token",
type: "password",
required: true,
},
{
name: "app_token",
label: "App token",
type: "password",
required: true,
},
],
credential_values: slackConfigured
? {
bot_token: "********",
app_token: "********",
}
: {},
},
{
provider: "discord",
display_name: "Discord",
enabled: false,
configured: false,
connectable: false,
auth_mode: "binding_code",
connection_status: "not_connected",
credential_fields: [],
},
],
}),
});
});
void page.route("**/api/channels/connections", (route) => {
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ connections: [] }),
});
});
void page.route("**/api/channels/slack/runtime-config", async (route) => {
const body = route.request().postDataJSON() as {
values: Record<string, string>;
};
submittedValues = body.values;
slackConfigured = true;
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
provider: "slack",
display_name: "Slack",
enabled: true,
configured: true,
connectable: true,
auth_mode: "binding_code",
connection_status: "connected",
credential_fields: [],
credential_values: {},
}),
});
});
void page.route("**/api/channels/slack/connect", (route) => route.abort());
await page.goto("/workspace/chats/new");
const sidebar = page.locator("[data-sidebar='sidebar']");
await expect(sidebar.getByText("Slack")).toBeVisible({ timeout: 15_000 });
await expect(sidebar.getByText("Discord")).toBeHidden();
const connectButton = sidebar.getByRole("button", { name: "Connect" });
await expect(connectButton).toBeEnabled();
await connectButton.click();
const setupDialog = page.getByRole("dialog", { name: "Connect Slack" });
await expect(setupDialog).toBeVisible();
const botTokenInput = setupDialog.getByLabel("Bot token");
await expect(botTokenInput).toHaveAttribute("type", "text");
await expect(botTokenInput).toHaveAttribute("autocomplete", "off");
await expect(botTokenInput).toHaveAttribute("data-lpignore", "true");
await expect(botTokenInput).toHaveAttribute("data-1p-ignore", "true");
await expect(botTokenInput).toHaveCSS("-webkit-text-security", "disc");
await setupDialog.getByLabel("Bot token").fill("xoxb-ui");
await setupDialog.getByLabel("App token").fill("xapp-ui");
await setupDialog.getByRole("button", { name: "Save and connect" }).click();
await expect(setupDialog).toBeHidden();
await expect(
sidebar.getByRole("button", { name: "Connected" }),
).toBeVisible();
await sidebar.getByRole("button", { name: "Connected" }).click();
await expect(
page.getByRole("dialog", { name: "Modify Slack" }),
).toBeVisible();
await expect(page.getByLabel("Bot token")).toHaveValue("********");
await expect(page.getByLabel("App token")).toHaveValue("********");
expect(submittedValues).toEqual({
bot_token: "xoxb-ui",
app_token: "xapp-ui",
});
});
test("configured provider connects directly with a binding-code instruction", async ({
page,
}) => {
mockLangGraphAPI(page);
let slackConnectCalls = 0;
mockChannelsAPI(
page,
[
{
provider: "slack",
display_name: "Slack",
enabled: true,
configured: true,
connectable: true,
auth_mode: "binding_code",
connection_status: "not_connected",
credential_fields: [
{
name: "bot_token",
label: "Bot token",
type: "password",
required: true,
},
],
credential_values: { bot_token: "********" },
},
],
() => {
slackConnectCalls += 1;
},
);
await page.goto("/workspace/chats/new");
const sidebar = page.locator("[data-sidebar='sidebar']");
await expect(sidebar.getByText("Slack")).toBeVisible({ timeout: 15_000 });
await sidebar.getByRole("button", { name: "Connect" }).click();
await expect(
page.getByText("Send /connect abc123 to the DeerFlow Slack bot."),
).toBeVisible();
expect(slackConnectCalls).toBe(1);
});
test("runtime setup continues into the connect flow when a binding is still required", async ({
page,
}) => {
mockLangGraphAPI(page);
let slackConfigured = false;
let slackConnectCalls = 0;
void page.route("**/api/channels/providers", (route) => {
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
enabled: true,
providers: [
{
provider: "slack",
display_name: "Slack",
enabled: true,
configured: slackConfigured,
connectable: slackConfigured,
auth_mode: "binding_code",
connection_status: "not_connected",
credential_fields: [
{
name: "bot_token",
label: "Bot token",
type: "password",
required: true,
},
],
credential_values: {},
},
],
}),
});
});
void page.route("**/api/channels/connections", (route) => {
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({ connections: [] }),
});
});
void page.route("**/api/channels/slack/runtime-config", (route) => {
slackConfigured = true;
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
provider: "slack",
display_name: "Slack",
enabled: true,
configured: true,
connectable: true,
auth_mode: "binding_code",
connection_status: "not_connected",
credential_fields: [],
credential_values: {},
}),
});
});
void page.route("**/api/channels/slack/connect", (route) => {
slackConnectCalls += 1;
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
provider: "slack",
mode: "binding_code",
url: null,
code: "abc123",
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
expires_in: 600,
}),
});
});
await page.goto("/workspace/chats/new");
const sidebar = page.locator("[data-sidebar='sidebar']");
await expect(sidebar.getByText("Slack")).toBeVisible({ timeout: 15_000 });
await sidebar.getByRole("button", { name: "Connect" }).click();
const setupDialog = page.getByRole("dialog", { name: "Connect Slack" });
await expect(setupDialog).toBeVisible();
await setupDialog.getByLabel("Bot token").fill("xoxb-ui");
await setupDialog.getByRole("button", { name: "Save and connect" }).click();
await expect(setupDialog).toBeHidden();
await expect(
page.getByText("Send /connect abc123 to the DeerFlow Slack bot."),
).toBeVisible();
expect(slackConnectCalls).toBe(1);
});
test("runtime setup dialog prefills editable credential values", async ({
page,
}) => {
mockLangGraphAPI(page);
mockChannelsAPI(page, [
{
provider: "feishu",
display_name: "Feishu",
enabled: true,
configured: true,
connectable: true,
auth_mode: "binding_code",
connection_status: "connected",
credential_fields: [
{
name: "app_id",
label: "App ID",
type: "text",
required: true,
},
{
name: "app_secret",
label: "App secret",
type: "password",
required: true,
},
],
credential_values: {
app_id: "cli_feishu_app",
app_secret: "********",
},
},
]);
await page.goto("/workspace/chats/new");
const sidebar = page.locator("[data-sidebar='sidebar']");
await expect(sidebar.getByText("Feishu")).toBeVisible({ timeout: 15_000 });
await sidebar.getByRole("button", { name: "Connected" }).click();
const setupDialog = page.getByRole("dialog", { name: "Modify Feishu" });
await expect(setupDialog).toBeVisible();
await expect(setupDialog.getByLabel("App ID")).toHaveValue(
"cli_feishu_app",
);
await expect(setupDialog.getByLabel("App secret")).toHaveValue("********");
});
});
+39
View File
@@ -258,4 +258,43 @@ test.describe("Thread history", () => {
}); });
await expect(main.getByText("Second conversation")).toBeVisible(); await expect(main.getByText("Second conversation")).toBeVisible();
}); });
test("IM channel threads show their source in thread lists", async ({
page,
}) => {
mockLangGraphAPI(page, {
threads: [
{
thread_id: MOCK_THREAD_ID,
title: "Feishu conversation",
updated_at: "2025-06-03T12:00:00Z",
metadata: {
channel_source: {
type: "im_channel",
provider: "feishu",
chat_id: "oc_mock",
},
},
},
],
});
await page.goto("/workspace/chats/new");
const sidebarThread = page.locator(
`a[href='/workspace/chats/${MOCK_THREAD_ID}']`,
);
await expect(sidebarThread).toBeVisible({ timeout: 15_000 });
await expect(sidebarThread.getByLabel("Feishu channel")).toBeVisible();
await page.goto("/workspace/chats");
const mainThread = page
.locator("main")
.locator(`a[href='/workspace/chats/${MOCK_THREAD_ID}']`);
await expect(mainThread.getByText("Feishu conversation")).toBeVisible({
timeout: 15_000,
});
await expect(mainThread.getByText("Feishu", { exact: true })).toBeVisible();
});
}); });
+5 -1
View File
@@ -25,6 +25,7 @@ export type MockThread = {
title?: string; title?: string;
updated_at?: string; updated_at?: string;
agent_name?: string; agent_name?: string;
metadata?: Record<string, unknown>;
messages?: unknown[]; messages?: unknown[];
artifacts?: string[]; artifacts?: string[];
}; };
@@ -90,7 +91,10 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
thread_id: t.thread_id, thread_id: t.thread_id,
created_at: "2025-01-01T00:00:00Z", created_at: "2025-01-01T00:00:00Z",
updated_at: t.updated_at ?? "2025-01-01T00:00:00Z", updated_at: t.updated_at ?? "2025-01-01T00:00:00Z",
metadata: t.agent_name ? { agent_name: t.agent_name } : {}, metadata: {
...(t.metadata ?? {}),
...(t.agent_name ? { agent_name: t.agent_name } : {}),
},
status: "idle", status: "idle",
values: { title: t.title ?? "Untitled" }, values: { title: t.title ?? "Untitled" },
})); }));
@@ -0,0 +1,220 @@
import { beforeEach, describe, expect, test, vi } from "vitest";
vi.mock("@/core/api/fetcher", () => ({
fetch: vi.fn(),
}));
vi.mock("@/core/config", () => ({
getBackendBaseURL: () => "/backend",
}));
import { fetch as fetcher } from "@/core/api/fetcher";
import {
configureChannelProvider,
connectChannelProvider,
disconnectChannelConnection,
disconnectChannelProvider,
listChannelConnections,
listChannelProviders,
} from "@/core/channels/api";
const mockedFetch = vi.mocked(fetcher);
function jsonResponse(status: number, body: unknown): Response {
return new Response(JSON.stringify(body), {
status,
statusText: status >= 400 ? "Bad Request" : "OK",
headers: { "Content-Type": "application/json" },
});
}
beforeEach(() => {
mockedFetch.mockReset();
});
describe("channels api", () => {
test("loads provider catalog", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
enabled: true,
providers: [
{
provider: "telegram",
display_name: "Telegram",
enabled: true,
configured: true,
auth_mode: "deep_link",
connection_status: "not_connected",
credential_values: {
bot_token: "********",
bot_username: "deerflow_bot",
},
},
],
}),
);
await expect(listChannelProviders()).resolves.toMatchObject({
enabled: true,
providers: [
{
provider: "telegram",
display_name: "Telegram",
credential_values: {
bot_token: "********",
bot_username: "deerflow_bot",
},
},
],
});
expect(mockedFetch).toHaveBeenCalledWith("/backend/api/channels/providers");
});
test("loads current user's connections", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
connections: [
{
id: "connection-1",
provider: "telegram",
status: "connected",
external_account_name: "Alice",
scopes: [],
metadata: {},
},
],
}),
);
await expect(listChannelConnections()).resolves.toMatchObject([
{ id: "connection-1", provider: "telegram", status: "connected" },
]);
expect(mockedFetch).toHaveBeenCalledWith(
"/backend/api/channels/connections",
);
});
test("starts a provider connection flow", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
provider: "telegram",
mode: "deep_link",
url: "https://t.me/deerflow_bot?start=state",
code: "state",
instruction: "Send /start state to the DeerFlow Telegram bot.",
expires_in: 600,
}),
);
await expect(connectChannelProvider("telegram")).resolves.toMatchObject({
provider: "telegram",
url: "https://t.me/deerflow_bot?start=state",
instruction: "Send /start state to the DeerFlow Telegram bot.",
});
expect(mockedFetch).toHaveBeenCalledWith(
"/backend/api/channels/telegram/connect",
{ method: "POST" },
);
});
test("starts a binding-code connection flow", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
provider: "slack",
mode: "binding_code",
url: null,
code: "abc123",
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
expires_in: 600,
}),
);
await expect(connectChannelProvider("slack")).resolves.toMatchObject({
provider: "slack",
url: null,
code: "abc123",
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
});
});
test("submits runtime provider configuration", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
provider: "slack",
display_name: "Slack",
enabled: true,
configured: true,
connectable: true,
auth_mode: "binding_code",
connection_status: "not_connected",
}),
);
await expect(
configureChannelProvider("slack", {
bot_token: "xoxb-ui",
app_token: "xapp-ui",
}),
).resolves.toMatchObject({
provider: "slack",
configured: true,
connectable: true,
});
expect(mockedFetch).toHaveBeenCalledWith(
"/backend/api/channels/slack/runtime-config",
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
values: { bot_token: "xoxb-ui", app_token: "xapp-ui" },
}),
},
);
});
test("disconnects a channel connection", async () => {
mockedFetch.mockResolvedValueOnce(new Response(null, { status: 204 }));
await expect(
disconnectChannelConnection("connection-1"),
).resolves.toBeUndefined();
expect(mockedFetch).toHaveBeenCalledWith(
"/backend/api/channels/connections/connection-1",
{ method: "DELETE" },
);
});
test("disconnects provider runtime configuration", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
provider: "slack",
display_name: "Slack",
enabled: true,
configured: false,
connectable: false,
auth_mode: "binding_code",
connection_status: "not_connected",
}),
);
await expect(disconnectChannelProvider("slack")).resolves.toMatchObject({
provider: "slack",
configured: false,
connection_status: "not_connected",
});
expect(mockedFetch).toHaveBeenCalledWith(
"/backend/api/channels/slack/runtime-config",
{ method: "DELETE" },
);
});
test("uses backend detail for failed requests", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(400, { detail: "Channel provider is not configured" }),
);
await expect(connectChannelProvider("slack")).rejects.toThrow(
"Channel provider is not configured",
);
});
});
@@ -0,0 +1,86 @@
import { afterEach, describe, expect, test, vi } from "vitest";
import {
closeConnectWindow,
openConnectUrl,
prepareConnectWindow,
} from "@/core/channels/open-connect-url";
type PopupStub = {
closed: boolean;
close: ReturnType<typeof vi.fn>;
location: {
replace: ReturnType<typeof vi.fn>;
};
opener: unknown;
};
function stubWindow(openResult: PopupStub | null) {
const assign = vi.fn();
const open = vi.fn(() => openResult);
vi.stubGlobal("window", {
open,
location: { assign },
});
return { assign, open };
}
function makePopup(): PopupStub {
return {
closed: false,
close: vi.fn(),
location: { replace: vi.fn() },
opener: {},
};
}
afterEach(() => {
vi.unstubAllGlobals();
});
describe("channel connect window helpers", () => {
test("opens a blank tab synchronously and detaches opener", () => {
const popup = makePopup();
const { open } = stubWindow(popup);
const prepared = prepareConnectWindow();
expect(open).toHaveBeenCalledWith("about:blank", "_blank");
expect(prepared).toBe(popup);
expect(popup.opener).toBeNull();
});
test("navigates a prepared popup without opening another window", () => {
const popup = makePopup();
const { assign, open } = stubWindow(null);
openConnectUrl(
"https://t.me/deerflow_bot?start=state",
popup as unknown as Window,
);
expect(open).not.toHaveBeenCalled();
expect(assign).not.toHaveBeenCalled();
expect(popup.location.replace).toHaveBeenCalledWith(
"https://t.me/deerflow_bot?start=state",
);
});
test("falls back to current-window navigation when no popup is available", () => {
const { assign } = stubWindow(null);
openConnectUrl("https://t.me/deerflow_bot?start=state");
expect(assign).toHaveBeenCalledWith(
"https://t.me/deerflow_bot?start=state",
);
});
test("closes a prepared popup on connect failure", () => {
const popup = makePopup();
closeConnectWindow(popup as unknown as Window);
expect(popup.close).toHaveBeenCalled();
});
});
@@ -0,0 +1,89 @@
import { describe, expect, it } from "vitest";
import {
providerCanConnect,
providerCanEditRuntimeConfig,
providerNeedsRuntimeConfig,
} from "@/core/channels/provider-state";
import type { ChannelProvider } from "@/core/channels/types";
function makeProvider(overrides: Partial<ChannelProvider>): ChannelProvider {
return {
provider: "slack",
display_name: "Slack",
enabled: true,
configured: true,
connectable: true,
auth_mode: "binding_code",
connection_status: "not_connected",
credential_fields: [
{
name: "bot_token",
label: "Bot token",
type: "password",
required: true,
},
],
...overrides,
};
}
describe("providerCanConnect", () => {
it("allows connecting a configured, not yet connected provider", () => {
expect(providerCanConnect(makeProvider({}))).toBe(true);
});
it("rejects an already connected provider", () => {
expect(
providerCanConnect(makeProvider({ connection_status: "connected" })),
).toBe(false);
});
it("rejects a non-connectable provider", () => {
expect(providerCanConnect(makeProvider({ connectable: false }))).toBe(
false,
);
});
it("falls back to enabled+configured when connectable is missing", () => {
expect(providerCanConnect(makeProvider({ connectable: undefined }))).toBe(
true,
);
expect(
providerCanConnect(
makeProvider({ connectable: undefined, configured: false }),
),
).toBe(false);
});
});
describe("providerNeedsRuntimeConfig", () => {
it("requires setup only when enabled and unconfigured with fields", () => {
expect(
providerNeedsRuntimeConfig(makeProvider({ configured: false })),
).toBe(true);
expect(providerNeedsRuntimeConfig(makeProvider({}))).toBe(false);
expect(
providerNeedsRuntimeConfig(
makeProvider({ configured: false, enabled: false }),
),
).toBe(false);
expect(
providerNeedsRuntimeConfig(
makeProvider({ configured: false, credential_fields: [] }),
),
).toBe(false);
});
});
describe("providerCanEditRuntimeConfig", () => {
it("is editable whenever enabled with credential fields", () => {
expect(providerCanEditRuntimeConfig(makeProvider({}))).toBe(true);
expect(providerCanEditRuntimeConfig(makeProvider({ enabled: false }))).toBe(
false,
);
expect(
providerCanEditRuntimeConfig(makeProvider({ credential_fields: [] })),
).toBe(false);
});
});
@@ -0,0 +1,19 @@
import { expect, test, vi } from "vitest";
import {
buildThreadsSearchQueryOptions,
DEFAULT_THREAD_SEARCH_PARAMS,
THREAD_SEARCH_REFETCH_INTERVAL_MS,
} from "@/core/threads/thread-search-query";
test("thread search query refreshes so IM-created sessions appear in the sidebar", () => {
const search = vi.fn();
const options = buildThreadsSearchQueryOptions(
{ threads: { search } },
DEFAULT_THREAD_SEARCH_PARAMS,
);
expect(options.refetchInterval).toBe(THREAD_SEARCH_REFETCH_INTERVAL_MS);
expect(options.refetchIntervalInBackground).toBe(false);
expect(options.refetchOnWindowFocus).toBe(false);
});
+38 -1
View File
@@ -1,6 +1,6 @@
import { expect, test } from "vitest"; import { expect, test } from "vitest";
import { pathOfThread } from "@/core/threads/utils"; import { channelSourceOfThread, pathOfThread } from "@/core/threads/utils";
test("uses standard chat route when thread has no agent context", () => { test("uses standard chat route when thread has no agent context", () => {
expect(pathOfThread("thread-123")).toBe("/workspace/chats/thread-123"); expect(pathOfThread("thread-123")).toBe("/workspace/chats/thread-123");
@@ -44,3 +44,40 @@ test("prefers context.agent_name over metadata.agent_name", () => {
}), }),
).toBe("/workspace/agents/from-context/chats/thread-789"); ).toBe("/workspace/agents/from-context/chats/thread-789");
}); });
test("reads IM channel source metadata", () => {
expect(
channelSourceOfThread({
metadata: {
channel_source: {
type: "im_channel",
provider: "feishu",
chat_id: "oc_123",
},
},
}),
).toEqual({
type: "im_channel",
provider: "feishu",
label: "Feishu",
});
});
test("ignores threads without valid IM channel source metadata", () => {
expect(channelSourceOfThread({ metadata: {} })).toBeNull();
expect(
channelSourceOfThread({
metadata: { channel_source: { provider: "" } },
}),
).toBeNull();
expect(
channelSourceOfThread({
metadata: {
channel_source: {
type: "other",
provider: "feishu",
},
},
}),
).toBeNull();
});
@@ -0,0 +1,61 @@
import { afterEach, describe, expect, test, vi } from "vitest";
type KeydownHandler = (event: KeyboardEvent) => void;
async function loadHookWithCapturedHandler() {
let cleanup: (() => void) | undefined;
let keydownHandler: KeydownHandler | undefined;
const addEventListener = vi.fn(
(type: string, listener: EventListenerOrEventListenerObject) => {
if (type === "keydown" && typeof listener === "function") {
keydownHandler = listener as KeydownHandler;
}
},
);
const removeEventListener = vi.fn();
vi.resetModules();
vi.doMock("react", () => ({
useEffect: (effect: () => void | (() => void)) => {
const result = effect();
cleanup = typeof result === "function" ? result : undefined;
},
}));
vi.stubGlobal("window", { addEventListener, removeEventListener });
const { useGlobalShortcuts } = await import("@/hooks/use-global-shortcuts");
return {
cleanup: () => cleanup?.(),
getKeydownHandler: () => keydownHandler,
useGlobalShortcuts,
};
}
afterEach(() => {
vi.doUnmock("react");
vi.unstubAllGlobals();
vi.resetModules();
});
describe("useGlobalShortcuts", () => {
test("ignores keydown events without a key", async () => {
const action = vi.fn();
const { getKeydownHandler, useGlobalShortcuts } =
await loadHookWithCapturedHandler();
useGlobalShortcuts([{ key: "k", meta: true, action }]);
const keydownHandler = getKeydownHandler();
expect(keydownHandler).toBeDefined();
expect(() =>
keydownHandler?.({
ctrlKey: false,
metaKey: true,
shiftKey: false,
} as KeyboardEvent),
).not.toThrow();
expect(action).not.toHaveBeenCalled();
});
});
+10 -1
View File
@@ -58,7 +58,7 @@ def main() -> int:
return 0 return 0
print() print()
total_steps = 4 total_steps = 5
from wizard.steps.llm import run_llm_step from wizard.steps.llm import run_llm_step
@@ -76,6 +76,10 @@ def main() -> int:
execution = run_execution_step(f"Step 3/{total_steps}") execution = run_execution_step(f"Step 3/{total_steps}")
from wizard.steps.channels import run_channels_step
channels = run_channels_step(f"Step 4/{total_steps}")
print_header(f"Step {total_steps}/{total_steps} · Writing configuration") print_header(f"Step {total_steps}/{total_steps} · Writing configuration")
write_config_yaml( write_config_yaml(
@@ -97,6 +101,7 @@ def main() -> int:
allow_host_bash=execution.allow_host_bash, allow_host_bash=execution.allow_host_bash,
include_bash_tool=execution.include_bash_tool, include_bash_tool=execution.include_bash_tool,
include_write_tools=execution.include_write_tools, include_write_tools=execution.include_write_tools,
channel_connection_providers=channels.enabled_providers,
) )
print_success(f"Config written to: {config_path.relative_to(project_root)}") print_success(f"Config written to: {config_path.relative_to(project_root)}")
@@ -148,6 +153,10 @@ def main() -> int:
print(f" {green('')} File write: enabled") print(f" {green('')} File write: enabled")
else: else:
print(f" {'':>3} File write: disabled") print(f" {'':>3} File write: disabled")
if channels.enabled_providers:
print(f" {green('')} IM channels: {', '.join(channels.enabled_providers)}")
else:
print(f" {'':>3} IM channels: disabled")
print() print()
print("Next steps:") print("Next steps:")
print(f" {cyan('make install')} # Install dependencies (first time only)") print(f" {cyan('make install')} # Install dependencies (first time only)")
+46
View File
@@ -0,0 +1,46 @@
"""Step: browser-connectable IM channel enablement."""
from __future__ import annotations
from dataclasses import dataclass
from wizard.ui import ask_multi_choice, print_header, print_info, print_success
CHANNEL_CONNECTION_OPTIONS: tuple[tuple[str, str, str], ...] = (
("telegram", "Telegram", "direct messages through your DeerFlow bot"),
("slack", "Slack", "workspace messages and mentions"),
("discord", "Discord", "server messages through your DeerFlow bot"),
("feishu", "Feishu / Lark", "messages through your DeerFlow app"),
("dingtalk", "DingTalk", "Stream Push messages through your DeerFlow bot"),
("wechat", "WeChat", "iLink messages through your DeerFlow bot"),
("wecom", "WeCom", "messages through your DeerFlow AI bot"),
)
@dataclass
class ChannelConnectionsStepResult:
enabled_providers: list[str]
def run_channels_step(step_label: str = "Step 4/5") -> ChannelConnectionsStepResult:
print_header(f"{step_label} · IM Channels (optional)")
print_info("Choose which IM channels should appear in the DeerFlow sidebar and Settings.")
print_info("Credentials can be entered later from the browser with Connect or Modify.")
print()
options = [f"{display_name}{description}" for _, display_name, description in CHANNEL_CONNECTION_OPTIONS]
selected = ask_multi_choice(
"Enable channels (comma-separated numbers, 'all', or Enter for none)",
options,
default=[],
)
enabled_providers = [CHANNEL_CONNECTION_OPTIONS[idx][0] for idx in selected]
if enabled_providers:
display_names = [CHANNEL_CONNECTION_OPTIONS[idx][1] for idx in selected]
print_success(f"Enabled channels: {', '.join(display_names)}")
else:
print_info("No IM channels selected; channel connections will stay disabled.")
return ChannelConnectionsStepResult(enabled_providers=enabled_providers)
+43
View File
@@ -224,6 +224,49 @@ def ask_choice(prompt: str, options: list[str], default: int | None = None) -> i
return _ask_choice_with_numbers(prompt, options, default=default) return _ask_choice_with_numbers(prompt, options, default=default)
def ask_multi_choice(prompt: str, options: list[str], default: list[int] | None = None) -> list[int]:
"""Present a numbered multi-select menu and return 0-based indexes."""
has_default = default is not None
default_indexes = list(default or [])
for i, opt in enumerate(options, 1):
marker = f" {green('*')}" if has_default and i - 1 in default_indexes else " "
print(f"{marker} {i}. {opt}")
print()
suffix = ""
if default_indexes:
suffix = f" [{','.join(str(idx + 1) for idx in default_indexes)}]"
elif has_default:
suffix = " [none]"
while True:
raw = input(f"{prompt}{suffix}: ").strip().lower()
if raw == "" and has_default:
return default_indexes
if raw in {"none", "no", "n", "skip"}:
return []
if raw == "all":
return list(range(len(options)))
parts = [part.strip() for part in raw.replace(" ", ",").split(",") if part.strip()]
selected: list[int] = []
valid = bool(parts)
for part in parts:
if not part.isdigit():
valid = False
break
idx = int(part) - 1
if not 0 <= idx < len(options):
valid = False
break
if idx not in selected:
selected.append(idx)
if valid:
return selected
print(f" Enter comma-separated numbers between 1 and {len(options)}, 'all', or 'none'.")
def ask_text(prompt: str, default: str = "", required: bool = False) -> str: def ask_text(prompt: str, default: str = "", required: bool = False) -> str:
"""Ask for a text value, returning default if the user presses Enter.""" """Ask for a text value, returning default if the user presses Enter."""
suffix = f" [{default}]" if default else "" suffix = f" [{default}]" if default else ""
+27
View File
@@ -12,6 +12,16 @@ from typing import Any
import yaml import yaml
CHANNEL_CONNECTION_PROVIDERS: tuple[str, ...] = (
"telegram",
"slack",
"discord",
"feishu",
"dingtalk",
"wechat",
"wecom",
)
def _project_root() -> Path: def _project_root() -> Path:
return Path(__file__).resolve().parents[2] return Path(__file__).resolve().parents[2]
@@ -151,6 +161,18 @@ def _make_model_config_name(model_name: str) -> str:
return base.replace(".", "-") return base.replace(".", "-")
def _build_channel_connections_config(enabled_providers: list[str]) -> dict[str, Any]:
selected = set(enabled_providers)
unknown = selected.difference(CHANNEL_CONNECTION_PROVIDERS)
if unknown:
raise ValueError(f"Unknown channel connection provider(s): {', '.join(sorted(unknown))}")
return {
"enabled": bool(selected),
**{provider: {"enabled": provider in selected} for provider in CHANNEL_CONNECTION_PROVIDERS},
}
def build_minimal_config( def build_minimal_config(
*, *,
provider_use: str, provider_use: str,
@@ -170,6 +192,7 @@ def build_minimal_config(
allow_host_bash: bool = False, allow_host_bash: bool = False,
include_bash_tool: bool = False, include_bash_tool: bool = False,
include_write_tools: bool = True, include_write_tools: bool = True,
channel_connection_providers: list[str] | None = None,
config_version: int = 5, config_version: int = 5,
base_config: dict[str, Any] | None = None, base_config: dict[str, Any] | None = None,
) -> str: ) -> str:
@@ -219,6 +242,8 @@ def build_minimal_config(
else: else:
sandbox_config.pop("allow_host_bash", None) sandbox_config.pop("allow_host_bash", None)
data["sandbox"] = sandbox_config data["sandbox"] = sandbox_config
if channel_connection_providers is not None:
data["channel_connections"] = _build_channel_connections_config(channel_connection_providers)
header = ( header = (
f"# DeerFlow Configuration\n" f"# DeerFlow Configuration\n"
@@ -250,6 +275,7 @@ def write_config_yaml(
allow_host_bash: bool = False, allow_host_bash: bool = False,
include_bash_tool: bool = False, include_bash_tool: bool = False,
include_write_tools: bool = True, include_write_tools: bool = True,
channel_connection_providers: list[str] | None = None,
) -> None: ) -> None:
"""Write (or overwrite) config.yaml with a minimal working configuration.""" """Write (or overwrite) config.yaml with a minimal working configuration."""
# Read config_version from config.example.yaml if present # Read config_version from config.example.yaml if present
@@ -284,6 +310,7 @@ def write_config_yaml(
allow_host_bash=allow_host_bash, allow_host_bash=allow_host_bash,
include_bash_tool=include_bash_tool, include_bash_tool=include_bash_tool,
include_write_tools=include_write_tools, include_write_tools=include_write_tools,
channel_connection_providers=channel_connection_providers,
config_version=config_version, config_version=config_version,
base_config=example_defaults, base_config=example_defaults,
) )