feat(im): Add user-owned IM channel connections (#3487)

* Add user-owned IM channel connections

* Fix dev startup and channel connect popup

* Use async channel connect flow

* Harden dev service daemon startup

* Support local IM channel connections

* Align IM connections with local channels

* Fix safe user id digest algorithm

* Address Copilot IM channel feedback

* Address IM channel review comments

* Support all integrated IM channel connections

* Format additional channel connection tests

* Keep unavailable channel connect buttons clickable

* Fix IM channel provider icons

* Add runtime setup for enabled IM channels

* Guard global shortcut key handling

* Keep configured IM channels editable

* Avoid password autofill for channel secrets

* Make channel threads visible to connection owners

* Persist IM runtime config locally

* Allow disconnecting runtime IM channels

* Route no-auth channel sessions to local user

* Use default user for auth-disabled local mode

* Show IM channel source on threads

* Prefill IM channel runtime config

* Reflect IM channel runtime health

* Ignore Feishu message read events

* Ignore Feishu non-content message events

* Let setup wizard enable IM channels

* Fix frontend formatting after merge

* Stabilize backend tests without local config

* Isolate channel runtime config tests

* Address channel connection review comments

* Use sha256 user buckets with legacy migration

* Ensure runtime IM channels are ready after restart

* Persist disconnected IM channel state

* Address channel connection review comments

* Address channel connection review findings

Frontend connect flow:
- Open the runtime-config dialog only when a provider still needs
  credentials; configured providers go straight to the connect flow, so
  the binding-code/deep-link path is reachable from the UI again.
- After saving credentials, continue into the connect flow when a user
  binding is still required (multi-user mode) instead of stopping at a
  "Connected" toast.
- Extract shared provider-state helpers to core/channels/provider-state
  and add unit + e2e coverage for the direct-connect and
  configure-then-connect paths.

Provider status semantics:
- Report connection_status from the user's newest connection row;
  with no binding it is not_connected, except in auth-disabled local
  mode where a configured running channel is effectively connected.

Concurrency and event-loop correctness:
- Offload ChannelRuntimeConfigStore construction and writes, channel
  service construction, and Slack connection replies to threads; add a
  tests/blocking_io/ anchor for the runtime-config handlers.
- Consume binding codes with a conditional UPDATE so a code can only be
  used once under concurrent workers; retry upsert_connection as an
  update when a concurrent insert wins the unique constraint.
- Serialize ensure_channel_ready per channel so concurrent provider
  polls cannot double-start a channel worker.

Config and migration hardening:
- Stop mutating the get_app_config()-cached Telegram provider config;
  the runtime store now owns the UI-entered bot username.
- Register channel_connections in STARTUP_ONLY_FIELDS with the
  standardized startup-only Field description.
- Match the legacy unsafe-id bucket by recomputing its exact SHA-1 name
  so another user's same-prefix bucket can never be migrated.
- Remove the unused Telegram process_webhook_update path and document
  src/core/channels in the frontend docs.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Address PR review comments on authz scoping and channel runtime

Security (review feedback from ShenAC-SAC):
- Scope internal-token callers to the connection owner carried in
  X-DeerFlow-Owner-User-Id instead of bypassing owner checks outright,
  in both require_permission(owner_check=True) and the stateless run
  endpoints. Internal callers keep access to their own and
  shared/legacy threads, and may claim a default-owned channel thread
  for its real owner, but a leaked internal token no longer grants
  cross-user thread access.
- Require admin privileges for POST/DELETE /api/channels/{provider}/
  runtime-config: runtime credentials and channel workers are
  instance-wide shared state (same model as the MCP config API).
  Read-only provider listing stays available to all users.

Performance (review feedback from willem-bd):
- Skip the redundant thread channel-metadata PATCH after the first
  successful backfill per thread.
- Reuse the per-connection Slack WebClient until its token changes
  instead of constructing one per outbound message.
- Reconcile channel readiness for all providers concurrently in
  GET /api/channels/providers.

Also resolve the code-quality unused-import flag in the blocking-io
anchor by pre-importing the channel service via importlib.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Fix prettier formatting in provider-state test

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Reconcile UI runtime channel config with config reload on restart

Main now reloads a channel's config.yaml entry on restart_channel()
(#3514, issue #3497). Adapt the user-owned connection flow to coexist:

- configure_channel() restarts with reload_config=False — the caller
  just supplied the authoritative config (browser-entered credentials
  that are never written to config.yaml), so a file reload must not
  clobber it with the stale on-disk entry.
- _load_channel_config() re-applies the UI runtime-store overlay used
  at startup, so an operator-triggered restart keeps browser-entered
  credentials for channels without a config.yaml entry and does not
  resurrect a channel disconnected from the UI.
- Offload the reload's disk IO (config.yaml + runtime store) with
  asyncio.to_thread, matching the blocking-IO policy on this branch.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-06-12 15:24:58 +08:00
committed by GitHub
parent b8f5ed360f
commit aa015462a7
96 changed files with 8585 additions and 277 deletions
+11
View File
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
)
def extract_connect_code(text: str) -> str | None:
"""Extract the one-time channel binding code from a connect command."""
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command."""
if not text.startswith("/"):
@@ -0,0 +1,44 @@
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
from __future__ import annotations
from typing import Any
from app.channels.message_bus import InboundMessage
async def attach_connection_identity(
inbound: InboundMessage,
*,
repo: Any,
provider: str,
workspace_id: str | None,
fallback_without_workspace: bool = False,
) -> InboundMessage:
"""Attach connection metadata to an inbound message when a persisted binding exists."""
if repo is None:
return inbound
workspace_candidates: list[str | None] = []
if workspace_id:
workspace_candidates.append(workspace_id)
if fallback_without_workspace:
workspace_candidates.append(None)
if not workspace_candidates:
return inbound
for candidate in workspace_candidates:
connection = await repo.find_connection_by_external_identity(
provider=provider,
external_account_id=inbound.user_id,
workspace_id=candidate,
)
if connection is None:
continue
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return inbound
+105 -1
View File
@@ -14,7 +14,8 @@ from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
self._incoming_messages: dict[str, Any] = {}
self._incoming_messages_lock = threading.Lock()
self._card_repliers: dict[str, Any] = {}
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
text[:100],
)
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
conversation_type=conversation_type,
sender_staff_id=sender_staff_id,
sender_nick=sender_nick,
conversation_id=conversation_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
return
if _is_dingtalk_command(text):
msg_type = InboundMessageType.COMMAND
else:
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
return ""
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
inbound = await self._attach_connection_identity(inbound)
# Running reply must finish before publish_inbound so AI card tracks are
# registered before the manager emits streaming outbounds.
await self._send_running_reply(chat_id, inbound)
await self.bus.publish_inbound(inbound)
@staticmethod
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
return conversation_id
return None
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
conversation_id = str(inbound.metadata.get("conversation_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="dingtalk",
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(
self,
*,
conversation_type: str,
sender_staff_id: str,
sender_nick: str,
conversation_id: str,
code: str,
) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
if state is None:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection code is invalid or expired.",
)
return True
if not sender_staff_id:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection could not be completed from this message.",
)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="dingtalk",
external_account_id=sender_staff_id,
external_account_name=sender_nick or None,
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
metadata={
"conversation_type": conversation_type,
"conversation_id": conversation_id,
},
status="connected",
)
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connected to DeerFlow.",
)
return True
async def _send_connection_reply(
self,
conversation_type: str,
sender_staff_id: str,
conversation_id: str,
text: str,
) -> None:
robot_code = self._client_id
if conversation_type == _CONVERSATION_TYPE_GROUP:
if conversation_id:
await self._send_text_message_to_group(robot_code, conversation_id, text)
return
if sender_staff_id:
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
+64 -2
View File
@@ -10,8 +10,9 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -287,6 +289,10 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
# --- Determine thread/channel routing and typing target ---
thread_id = None
chat_id = None
@@ -315,6 +321,7 @@ class DiscordChannel(Channel):
},
)
inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
self._publish(inbound)
# Start typing indicator in the thread
if typing_target:
@@ -422,6 +429,7 @@ class DiscordChannel(Channel):
},
)
inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
# Start typing indicator in the correct target (thread or channel)
if typing_target:
@@ -436,6 +444,60 @@ class DiscordChannel(Channel):
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="discord",
workspace_id=guild_id,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
if state is None:
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
return True
guild = getattr(message, "guild", None)
channel = getattr(message, "channel", None)
author = getattr(message, "author", None)
user_id = str(getattr(author, "id", "") or "")
if not user_id:
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
return True
guild_id = str(getattr(guild, "id", "") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="discord",
external_account_id=user_id,
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
workspace_id=guild_id,
workspace_name=getattr(guild, "name", None) if guild is not None else None,
metadata={
"guild_id": guild_id,
"channel_id": str(getattr(channel, "id", "") or ""),
},
status="connected",
)
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
return True
@staticmethod
async def _send_connection_reply(message, text: str) -> None:
channel = getattr(message, "channel", None)
send = getattr(channel, "send", None)
if send is None:
return
try:
await send(text)
except Exception:
logger.exception("[Discord] failed to send connection reply")
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
+78 -2
View File
@@ -11,7 +11,8 @@ import time
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
self._CreateImageRequestBody = None
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
self._connection_repo = config.get("connection_repo")
@staticmethod
def _non_empty_str(value: Any) -> str | None:
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
def supports_streaming(self) -> bool:
return True
@property
def is_running(self) -> bool:
if not self._running:
return False
return self._thread is not None and self._thread.is_alive()
def _build_event_handler(self, lark):
return (
lark.EventDispatcherHandler.builder("", "")
.register_p2_im_message_receive_v1(self._on_message)
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
.build()
)
async def start(self) -> None:
if self._running:
return
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
# thread's uvloop.
_ws_client_mod.loop = loop
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
event_handler = self._build_event_handler(lark)
ws_client = lark.ws.Client(
app_id=app_id,
app_secret=app_secret,
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
except Exception:
if self._running:
logger.exception("Feishu WebSocket error")
self._running = False
def _on_ignored_message_event(self, event) -> None:
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
async def stop(self) -> None:
self._running = False
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
"""Kick off Feishu side effects without delaying inbound dispatch."""
inbound = await self._attach_connection_identity(inbound)
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
self._ensure_running_card_started(msg_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="feishu",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
if state is None:
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
return True
if not user_id or not chat_id:
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="feishu",
external_account_id=user_id,
workspace_id=chat_id,
metadata={
"chat_id": chat_id,
"message_id": message_id,
},
status="connected",
)
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
return True
def _on_message(self, event) -> None:
"""Called by lark-oapi when a message is received (runs in lark thread)."""
try:
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
logger.info("[Feishu] empty text, ignoring message")
return
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
message_id=msg_id,
chat_id=chat_id,
user_id=sender_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
return
# Only treat known slash commands as commands; absolute paths and
# other slash-prefixed text should be handled as normal chat.
if _is_feishu_command(text):
+152 -30
View File
@@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
return metadata
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
channel_source: dict[str, Any] = {
"type": "im_channel",
"provider": msg.channel_name,
"chat_id": msg.chat_id,
}
if msg.topic_id:
channel_source["topic_id"] = msg.topic_id
if msg.thread_ts:
channel_source["thread_ts"] = msg.thread_ts
if msg.connection_id:
channel_source["connection_id"] = msg.connection_id
return {"channel_source": channel_source}
def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field."""
if isinstance(content, str):
@@ -440,6 +456,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
return message
def _auth_disabled_owner_user_id() -> str | None:
try:
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
except Exception:
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
return None
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
return _auth_disabled_owner_user_id() or msg.owner_user_id
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
msg.owner_user_id = owner_user_id
return msg
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
owner_user_id = _effective_owner_user_id(msg)
if not owner_user_id:
return None
return create_internal_auth_headers(owner_user_id=owner_user_id)
def _safe_user_id_for_run(raw_user_id: str) -> str:
from deerflow.config.paths import get_paths
try:
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
except Exception:
logger.exception("Failed to prepare channel run user directory")
return make_safe_user_id(raw_user_id)
def _resolve_slash_skill_command(
text: str,
available_skills: set[str] | None = None,
@@ -670,6 +723,7 @@ class ChannelManager:
assistant_id: str = DEFAULT_ASSISTANT_ID,
default_session: dict[str, Any] | None = None,
channel_sessions: dict[str, Any] | None = None,
connection_repo: Any | None = None,
) -> None:
self.bus = bus
self.store = store
@@ -679,7 +733,9 @@ class ChannelManager:
self._assistant_id = assistant_id
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._connection_repo = connection_repo
self._client = None # lazy init — langgraph_sdk async client
self._channel_metadata_synced: set[str] = set()
self._skill_storage: SkillStorage | None = None
self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None
@@ -728,12 +784,17 @@ class ChannelManager:
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
# ``user_id`` drives user-scoped filesystem buckets that only accept
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
# under ``channel_user_id`` for platform-facing lookups.
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
# For browser-connected IM channels, prefer the DeerFlow account that
# owns the connection. Preserve the raw platform user under
# ``channel_user_id`` for platform-facing lookups and audits.
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
elif msg.user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id
run_context = _merge_dicts(
@@ -845,6 +906,7 @@ class ChannelManager:
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
async def _handle_message(self, msg: InboundMessage) -> None:
msg = _apply_effective_owner(msg)
async with self._semaphore:
try:
if msg.msg_type == InboundMessageType.COMMAND:
@@ -877,10 +939,27 @@ class ChannelManager:
# -- chat handling -----------------------------------------------------
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
thread = await client.threads.create()
thread_id = thread["thread_id"]
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
if msg.connection_id and self._connection_repo is not None:
return await self._connection_repo.get_thread_id(
msg.connection_id,
msg.chat_id,
msg.topic_id,
)
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
await self._connection_repo.set_thread_id(
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
provider=msg.channel_name,
external_conversation_id=msg.chat_id,
external_topic_id=msg.topic_id,
thread_id=thread_id,
)
return
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
@@ -888,18 +967,49 @@ class ChannelManager:
topic_id=msg.topic_id,
user_id=msg.user_id,
)
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
metadata = _thread_channel_metadata(msg)
owner_headers = _owner_headers(msg)
if owner_headers:
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
else:
thread = await client.threads.create(metadata=metadata)
thread_id = thread["thread_id"]
await self._store_thread_id(msg, thread_id)
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
return thread_id
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
"""Best-effort source metadata backfill for existing IM-created threads."""
# The metadata (provider/chat/topic) is constant for a thread, so one
# successful backfill per manager lifetime is enough — skip the
# redundant PATCH on every subsequent inbound message.
if thread_id in self._channel_metadata_synced:
return
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
if owner_headers := _owner_headers(msg):
update_kwargs["headers"] = owner_headers
try:
await client.threads.update(thread_id, **update_kwargs)
except Exception:
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
return
if len(self._channel_metadata_synced) > 4096:
self._channel_metadata_synced.clear()
self._channel_metadata_synced.add(thread_id)
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
client = self._get_client()
# Look up existing DeerFlow thread.
# topic_id may be None (e.g. Telegram private chats) — the store
# handles this by using the "channel:chat_id" key without a topic suffix.
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
thread_id = await self._lookup_thread_id(msg)
if thread_id:
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
await self._update_thread_channel_metadata(client, msg, thread_id)
# No existing thread found — create a new one
if thread_id is None:
@@ -940,14 +1050,19 @@ class ChannelManager:
return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
run_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
run_kwargs["headers"] = owner_headers
try:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [human_message]},
config=run_config,
context=run_context,
multitask_strategy="reject",
**run_kwargs,
)
except Exception as exc:
if _is_thread_busy_error(exc):
@@ -984,6 +1099,8 @@ class ChannelManager:
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
@@ -1008,16 +1125,21 @@ class ChannelManager:
last_published_text = ""
last_publish_at = 0.0
stream_error: BaseException | None = None
stream_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"stream_mode": ["messages-tuple", "values"],
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
stream_kwargs["headers"] = owner_headers
try:
async for chunk in client.runs.stream(
thread_id,
assistant_id,
input={"messages": [human_message]},
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
multitask_strategy="reject",
**stream_kwargs,
):
event = getattr(chunk, "event", "")
data = getattr(chunk, "data", None)
@@ -1047,6 +1169,8 @@ class ChannelManager:
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata),
)
)
@@ -1093,6 +1217,8 @@ class ChannelManager:
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
)
)
@@ -1124,18 +1250,10 @@ class ChannelManager:
if reply is None and command == "new":
# Create a new thread through Gateway
client = self._get_client()
thread = await client.threads.create()
new_thread_id = thread["thread_id"]
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
new_thread_id,
topic_id=msg.topic_id,
user_id=msg.user_id,
)
await self._create_thread(client, msg)
reply = "New conversation started."
elif reply is None and command == "status":
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
thread_id = await self._lookup_thread_id(msg)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif reply is None and command == "models":
reply = await self._fetch_gateway("/api/models", "models")
@@ -1174,9 +1292,11 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=await self._lookup_thread_id(msg) or "",
text=reply,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
@@ -1212,9 +1332,11 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=await self._lookup_thread_id(msg) or "",
text=error_text,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
+14
View File
@@ -44,6 +44,12 @@ class InboundMessage:
Messages sharing the same ``topic_id`` within a ``chat_id`` will
reuse the same DeerFlow thread. When ``None``, each message
creates a new thread (one-shot Q&A).
connection_id: Optional DeerFlow channel connection id. When present,
conversation mapping is scoped by the connection instead of the
legacy global ``channel_name:chat_id[:topic_id]`` key.
owner_user_id: DeerFlow user id that owns the channel connection.
Platform user ids stay in ``user_id``.
workspace_id: Optional external workspace/guild/team id.
files: Optional list of file attachments (platform-specific dicts).
metadata: Arbitrary extra data from the channel.
created_at: Unix timestamp when the message was created.
@@ -56,6 +62,9 @@ class InboundMessage:
msg_type: InboundMessageType = InboundMessageType.CHAT
thread_ts: str | None = None
topic_id: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
workspace_id: str | None = None
files: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@@ -95,6 +104,9 @@ class OutboundMessage:
is_final: Whether this is the final message in the response stream.
thread_ts: Optional platform thread identifier for threaded replies.
metadata: Arbitrary extra data.
connection_id: Optional DeerFlow channel connection id used for
connection-specific outbound credentials.
owner_user_id: DeerFlow user id that owns the channel connection.
created_at: Unix timestamp.
"""
@@ -106,6 +118,8 @@ class OutboundMessage:
attachments: list[ResolvedAttachment] = field(default_factory=list)
is_final: bool = True
thread_ts: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@@ -0,0 +1,154 @@
"""Local persistence for runtime IM channel configuration."""
from __future__ import annotations
import json
import logging
import tempfile
import threading
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
class ChannelRuntimeConfigStore:
"""JSON-backed store for channel credentials entered from the UI.
This intentionally mirrors ``ChannelStore``: local/private deployments get
durable runtime configuration without needing a public callback URL or a
config.yaml edit.
"""
def __init__(self, path: str | Path | None = None) -> None:
if path is None:
from deerflow.config.paths import get_paths
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
self._path = Path(path)
self._path.parent.mkdir(parents=True, exist_ok=True)
self._data: dict[str, dict[str, Any]] = self._load()
self._lock = threading.Lock()
def _load(self) -> dict[str, dict[str, Any]]:
if self._path.exists():
try:
raw = json.loads(self._path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
return {}
if isinstance(raw, dict):
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
return {}
def _save(self) -> None:
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=self._path.parent,
suffix=".tmp",
delete=False,
)
try:
json.dump(self._data, fd, indent=2, ensure_ascii=False)
fd.close()
Path(fd.name).replace(self._path)
try:
self._path.chmod(0o600)
except OSError:
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def load_all(self) -> dict[str, dict[str, Any]]:
with self._lock:
return {name: dict(config) for name, config in self._data.items()}
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
with self._lock:
config = self._data.get(provider)
return dict(config) if isinstance(config, dict) else None
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
with self._lock:
self._data[provider] = dict(config)
self._save()
def set_provider_disconnected(self, provider: str) -> None:
with self._lock:
self._data[provider] = {
"enabled": False,
RUNTIME_CHANNEL_DISABLED_FLAG: True,
}
self._save()
def remove_provider_config(self, provider: str) -> bool:
with self._lock:
if provider not in self._data:
return False
del self._data[provider]
self._save()
return True
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
provider_config = getattr(channel_connections_config, provider, None)
return bool(getattr(provider_config, "enabled", False))
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
def merge_runtime_channel_configs(
channels_config: dict[str, Any],
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> None:
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return
runtime_store = store or ChannelRuntimeConfigStore()
for provider, runtime_config in runtime_store.load_all().items():
if not _provider_enabled(channel_connections_config, provider):
continue
if _runtime_channel_disconnected(runtime_config):
channels_config.pop(provider, None)
continue
existing = channels_config.get(provider)
merged = dict(runtime_config)
if isinstance(existing, dict):
merged.update(existing)
channels_config[provider] = merged
def apply_runtime_connection_config(
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> Any:
"""Apply persisted connection metadata that lives outside ``channels``.
Telegram uses a bot username for deep links; UI-entered values are stored
with the runtime channel config so local restarts keep the provider
configured.
"""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return channel_connections_config
runtime_store = store or ChannelRuntimeConfigStore()
telegram_runtime_config = runtime_store.get_provider_config("telegram")
bot_username = ""
if isinstance(telegram_runtime_config, dict):
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
return channel_connections_config
config = channel_connections_config.model_copy(deep=True)
config.telegram.bot_username = bot_username
return config
+145 -26
View File
@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
import os
from typing import TYPE_CHECKING, Any
@@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.runtime_config_store import merge_runtime_channel_configs
from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
@@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
value = config.pop(config_key, None)
if isinstance(value, str) and value.strip():
@@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
return default
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
connection_config = getattr(app_config, "channel_connections", None)
merge_runtime_channel_configs(channels_config, connection_config)
def _make_connection_repo(app_config: AppConfig):
connection_config = getattr(app_config, "channel_connections", None)
if connection_config is None or not getattr(connection_config, "enabled", False):
return None
try:
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
except Exception:
logger.exception("Failed to import channel connection repository")
return None
session_factory = get_session_factory()
if session_factory is None:
logger.warning("Channel connections are enabled but database persistence is not available")
return None
return ChannelConnectionRepository(session_factory)
class ChannelService:
"""Manages the lifecycle of all configured IM channels.
@@ -59,9 +90,10 @@ class ChannelService:
instantiates enabled channels, and starts the ChannelManager dispatcher.
"""
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
self.bus = MessageBus()
self.store = ChannelStore()
self._connection_repo = connection_repo
config = dict(channels_config or {})
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
@@ -74,10 +106,12 @@ class ChannelService:
gateway_url=gateway_url,
default_session=default_session if isinstance(default_session, dict) else None,
channel_sessions=channel_sessions,
connection_repo=connection_repo,
)
self._channels: dict[str, Any] = {} # name -> Channel instance
self._config = config
self._running = False
self._readiness_locks: dict[str, asyncio.Lock] = {}
@classmethod
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
@@ -90,8 +124,9 @@ class ChannelService:
# extra fields are allowed by AppConfig (extra="allow")
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
channels_config = dict(extra["channels"] or {})
_merge_channel_connection_runtime_config(channels_config, app_config)
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
async def start(self) -> None:
"""Start the manager and all enabled channels."""
@@ -99,36 +134,83 @@ class ChannelService:
return
await self.manager.start()
self._running = True
ready_status = await self.ensure_ready_channels(attempts=2)
ready_count = sum(1 for ready in ready_status.values() if ready)
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
"""Start or restart enabled configured channels that are not ready."""
ready_status: dict[str, bool] = {}
for name, channel_config in self._config.items():
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
if _channel_has_credentials(name, channel_config):
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
)
else:
logger.info("Channel %s is disabled, skipping", name)
logger.info("A configured channel is disabled, skipping")
continue
await self._start_channel(name, channel_config)
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
return ready_status
self._running = True
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
async def ensure_channel_ready(
self,
name: str,
config: dict[str, Any] | None = None,
*,
attempts: int = 1,
) -> bool:
"""Ensure a single enabled channel is running using its current config."""
if not self._running:
logger.warning("ChannelService is not running; cannot ensure channel readiness")
return False
if config is not None:
self._config[name] = dict(config)
# Serialize per channel: readiness is polled from request handlers, so
# concurrent calls must not stop/start the same channel worker twice.
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
async with lock:
channel_config = self._config.get(name)
if not channel_config or not isinstance(channel_config, dict):
logger.warning("No config for requested channel")
return False
if not channel_config.get("enabled", False):
return False
channel = self._channels.get(name)
if channel is not None and channel.is_running:
return True
if channel is not None:
try:
await channel.stop()
except Exception:
logger.exception("Error stopping non-running channel before readiness retry")
self._channels.pop(name, None)
max_attempts = max(1, attempts)
for attempt in range(max_attempts):
if attempt > 0:
logger.info("Retrying channel startup after readiness check")
if await self._start_channel(name, channel_config):
return True
return False
async def stop(self) -> None:
"""Stop all channels and the manager."""
for name, channel in list(self._channels.items()):
try:
await channel.stop()
logger.info("Channel %s stopped", name)
logger.info("Channel stopped")
except Exception:
logger.exception("Error stopping channel %s", name)
logger.exception("Error stopping channel")
self._channels.clear()
await self.manager.stop()
@@ -140,6 +222,9 @@ class ChannelService:
Uses ``get_app_config()`` which detects file changes via mtime,
so edits to ``config.yaml`` are picked up without a process restart.
The UI runtime-config overlay applied at startup is re-applied here
so a file-driven reload neither drops credentials entered from the
browser nor resurrects a channel disconnected from it.
Falls back to the cached ``self._config`` when config loading fails.
"""
try:
@@ -147,7 +232,8 @@ class ChannelService:
app_config = get_app_config()
extra = app_config.model_extra or {}
channels_config = extra.get("channels", {})
channels_config = dict(extra.get("channels") or {})
_merge_channel_connection_runtime_config(channels_config, app_config)
channel_config = channels_config.get(name)
if isinstance(channel_config, dict):
# Update the cached config so get_status() stays consistent.
@@ -157,18 +243,23 @@ class ChannelService:
logger.exception("Failed to reload config for channel %s, using cached version", name)
return self._config.get(name)
async def restart_channel(self, name: str) -> bool:
async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
"""Restart a specific channel. Returns True if successful."""
if name in self._channels:
try:
await self._channels[name].stop()
except Exception:
logger.exception("Error stopping channel %s for restart", name)
logger.exception("Error stopping channel for restart")
del self._channels[name]
config = self._load_channel_config(name)
if reload_config:
# Reading config.yaml and the runtime store is disk IO; keep it
# off the event loop.
config = await asyncio.to_thread(self._load_channel_config, name)
else:
config = self._config.get(name)
if not config or not isinstance(config, dict):
logger.warning("No config for channel %s", name)
logger.warning("No config for requested channel")
return False
if not config.get("enabled", False):
@@ -177,11 +268,35 @@ class ChannelService:
return await self._start_channel(name, config)
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Apply runtime config for a channel and restart it if the service is running."""
self._config[name] = dict(config)
if not self._running:
return True
# The caller just supplied the authoritative config (e.g. credentials
# entered in the browser that are never written to config.yaml) — a
# file reload here would clobber it with the stale on-disk entry.
return await self.restart_channel(name, reload_config=False)
async def remove_channel(self, name: str) -> bool:
"""Remove runtime config for a channel and stop it if currently running."""
self._config.pop(name, None)
channel = self._channels.pop(name, None)
if channel is None:
return True
try:
await channel.stop()
logger.info("Channel stopped and removed")
return True
except Exception:
logger.exception("Error stopping channel for removal")
return False
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Instantiate and start a single channel."""
import_path = _CHANNEL_REGISTRY.get(name)
if not import_path:
logger.warning("Unknown channel type: %s", name)
logger.warning("Unknown channel type")
return False
try:
@@ -189,24 +304,26 @@ class ChannelService:
channel_cls = resolve_class(import_path, base_class=None)
except Exception:
logger.exception("Failed to import channel class for %s", name)
logger.exception("Failed to import channel class")
return False
try:
config = dict(config)
config["channel_store"] = self.store
if self._connection_repo is not None:
config["connection_repo"] = self._connection_repo
channel = channel_cls(bus=self.bus, config=config)
self._channels[name] = channel
await channel.start()
if not channel.is_running:
self._channels.pop(name, None)
logger.error("Channel %s did not enter a running state after start()", name)
logger.error("Channel did not enter a running state after start()")
return False
logger.info("Channel %s started", name)
logger.info("Channel started")
return True
except Exception:
self._channels.pop(name, None)
logger.exception("Failed to start channel %s", name)
logger.exception("Failed to start channel")
return False
def get_status(self) -> dict[str, Any]:
@@ -245,7 +362,9 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config(app_config)
# from_app_config reads the JSON channel store and runtime config files;
# keep that disk IO off the event loop.
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
await _channel_service.start()
return _channel_service
+148 -26
View File
@@ -9,7 +9,8 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -64,6 +65,9 @@ class SlackChannel(Channel):
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._connection_repo = config.get("connection_repo")
self._web_client_factory = config.get("web_client_factory")
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
configured_bot_user_id = config.get("bot_user_id")
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
@@ -80,26 +84,28 @@ class SlackChannel(Channel):
return
self._SocketModeResponse = SocketModeResponse
if self._web_client_factory is None:
self._web_client_factory = WebClient
bot_token = self.config.get("bot_token", "")
app_token = self.config.get("app_token", "")
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
if not bot_token:
logger.error("Slack HTTP Events mode requires bot_token")
return
await self._initialize_operator_web_client(str(bot_token))
self._loop = asyncio.get_event_loop()
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
logger.info("Slack channel started in HTTP Events mode")
return
if not bot_token or not app_token:
logger.error("Slack channel requires bot_token and app_token")
return
self._web_client = WebClient(token=bot_token)
if self._bot_user_id is None:
try:
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
await self._initialize_operator_web_client(str(bot_token))
self._socket_client = SocketModeClient(
app_token=app_token,
web_client=self._web_client,
@@ -124,7 +130,8 @@ class SlackChannel(Channel):
logger.info("Slack channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._web_client:
web_client = await self._get_web_client_for_message(msg)
if not web_client:
return
kwargs: dict[str, Any] = {
@@ -137,11 +144,12 @@ class SlackChannel(Channel):
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
# Add a completion reaction to the thread root
if msg.thread_ts:
await asyncio.to_thread(
self._add_reaction,
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"white_check_mark",
@@ -165,7 +173,8 @@ class SlackChannel(Channel):
if msg.thread_ts:
try:
await asyncio.to_thread(
self._add_reaction,
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"x",
@@ -177,7 +186,8 @@ class SlackChannel(Channel):
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._web_client:
web_client = await self._get_web_client_for_message(msg)
if not web_client:
return False
try:
@@ -190,7 +200,7 @@ class SlackChannel(Channel):
if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
return True
except Exception:
@@ -199,12 +209,45 @@ class SlackChannel(Channel):
# -- internal ----------------------------------------------------------
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
async def _initialize_operator_web_client(self, bot_token: str) -> None:
self._web_client = self._web_client_factory(token=bot_token)
if self._bot_user_id is not None:
return
try:
self._web_client.reactions_add(
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
async def _get_web_client_for_message(self, msg: OutboundMessage):
if msg.connection_id and self._connection_repo is not None:
credentials = await self._connection_repo.get_credentials(msg.connection_id)
access_token = credentials.get("access_token") if credentials else None
if not access_token:
return self._web_client
# WebClient keeps its own HTTP session and rate-limit state, so
# reuse one per connection until its token changes.
cached = self._connection_web_clients.get(msg.connection_id)
if cached is not None and cached[0] == access_token:
return cached[1]
if self._web_client_factory is None:
from slack_sdk import WebClient
self._web_client_factory = WebClient
web_client = self._web_client_factory(token=access_token)
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
return web_client
return self._web_client
@staticmethod
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
try:
web_client.reactions_add(
channel=channel_id,
timestamp=timestamp,
name=emoji,
@@ -213,6 +256,12 @@ class SlackChannel(Channel):
if "already_reacted" not in str(exc):
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
return
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
if not self._web_client:
@@ -249,12 +298,15 @@ class SlackChannel(Channel):
# Handle message events (DM or @mention)
if etype in ("message", "app_mention"):
self._handle_message_event(event)
self._handle_message_event(
event,
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
)
except Exception:
logger.exception("Error processing Slack event")
def _handle_message_event(self, event: dict) -> None:
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
# Ignore bot messages
if event.get("bot_id") or event.get("subtype"):
return
@@ -272,6 +324,19 @@ class SlackChannel(Channel):
if not text:
return
connect_code = extract_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
event=event,
team_id=str(team_id or event.get("team") or ""),
code=connect_code,
),
self._loop,
)
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
@@ -297,4 +362,61 @@ class SlackChannel(Channel):
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
# Send "running" reply first (fire-and-forget from SDK thread)
self._send_running_reply(channel_id, thread_ts)
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
if self._connection_repo is None:
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
else:
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="slack",
workspace_id=workspace_id,
)
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
channel_id = str(event.get("channel") or "")
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
if state is None:
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
return True
user_id = str(event.get("user") or "")
if not user_id or not team_id:
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="slack",
external_account_id=user_id,
workspace_id=team_id,
metadata={
"team_id": team_id,
"channel_id": channel_id,
},
status="connected",
)
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
return True
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
if not self._web_client or not channel_id:
return
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
if thread_ts:
kwargs["thread_ts"] = thread_ts
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
except Exception:
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
+57
View File
@@ -8,6 +8,7 @@ import threading
from typing import Any
from app.channels.base import Channel
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ class TelegramChannel(Channel):
pass
# chat_id -> last sent message_id for threaded replies
self._last_bot_message: dict[str, int] = {}
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -233,6 +235,54 @@ class TelegramChannel(Channel):
return True
return user_id in self._allowed_users
@staticmethod
def _telegram_display_name(user) -> str:
full_name = getattr(user, "full_name", None)
if isinstance(full_name, str) and full_name:
return full_name
username = getattr(user, "username", None)
if isinstance(username, str) and username:
return username
return str(getattr(user, "id", ""))
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
if self._connection_repo is None or not state_token:
return False
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
if state is None:
await update.message.reply_text("Telegram connection link is invalid or expired.")
return True
owner_user_id = state["owner_user_id"]
user_id = str(update.effective_user.id)
chat_id = str(update.effective_chat.id)
connection = await self._connection_repo.upsert_connection(
owner_user_id=owner_user_id,
provider="telegram",
external_account_id=user_id,
external_account_name=self._telegram_display_name(update.effective_user),
workspace_id=chat_id,
workspace_name=None,
metadata={
"chat_id": chat_id,
"chat_type": update.effective_chat.type,
"telegram_username": getattr(update.effective_user, "username", None),
},
status="connected",
)
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
await update.message.reply_text("Telegram connected to DeerFlow.")
return True
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="telegram",
workspace_id=inbound.chat_id,
)
def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None)
username = getattr(bot, "username", None)
@@ -264,6 +314,11 @@ class TelegramChannel(Channel):
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
return
args = getattr(context, "args", []) if context is not None else []
if args:
handled = await self._bind_connection_from_start_token(update, str(args[0]))
if handled:
return
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
@@ -299,6 +354,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id,
)
inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
@@ -341,6 +397,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id,
)
inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
+60 -2
View File
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
self._connection_repo = config.get("connection_repo")
self._load_state()
async def start(self) -> None:
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
if thread_ts:
self._context_tokens_by_thread[thread_ts] = context_token
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
chat_id=chat_id,
context_token=context_token,
code=connect_code,
)
if handled:
return
inbound = self._make_inbound(
chat_id=chat_id,
user_id=chat_id,
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
},
)
inbound.topic_id = None
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wechat",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
if state is None:
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
return True
if not chat_id:
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wechat",
external_account_id=chat_id,
workspace_id=chat_id,
metadata={
"context_token": context_token,
},
status="connected",
)
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
return True
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
if not context_token:
return
await self._send_text_message(
chat_id=chat_id,
context_token=context_token,
text=text,
client_id_prefix="deerflow-connect",
max_retries=1,
)
async def _ensure_authenticated(self) -> bool:
async with self._auth_lock:
if self._bot_token:
+58 -1
View File
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid")
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
frame=frame,
user_id=str(user_id or ""),
code=connect_code,
)
if handled:
return
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
except Exception:
pass
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wecom",
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
if state is None:
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
return True
if not user_id:
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
return True
body = frame.get("body", {}) or {}
workspace_id = str(body.get("aibotid") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wecom",
external_account_id=user_id,
workspace_id=workspace_id,
metadata={
"aibotid": workspace_id,
"chattype": body.get("chattype"),
},
status="connected",
)
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
return True
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
if not self._ws_client:
return
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._ws_client:
return
+4
View File
@@ -16,6 +16,7 @@ from app.gateway.routers import (
artifacts,
assistants_compat,
auth,
channel_connections,
channels,
feedback,
mcp,
@@ -384,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
app.include_router(suggestions.router)
# User-facing IM channel connection API is mounted at /api/channels
app.include_router(channel_connections.router)
# Channels API is mounted at /api/channels
app.include_router(channels.router)
+4 -2
View File
@@ -6,9 +6,11 @@ import logging
import os
from types import SimpleNamespace
from deerflow.runtime.user_context import DEFAULT_USER_ID
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
AUTH_DISABLED_USER_ID = "e2e-user"
AUTH_DISABLED_USER_EMAIL = "e2e@test.local"
AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
AUTH_DISABLED_USER_EMAIL = "default@test.local"
AUTH_SOURCE_SESSION = "session"
AUTH_SOURCE_INTERNAL = "internal"
+18
View File
@@ -276,6 +276,8 @@ def require_permission(
# strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404.
if owner_check:
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
@@ -288,6 +290,22 @@ def require_permission(
str(auth.user.id),
require_existing=require_existing,
)
if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
# Trusted internal callers (channel workers) also act for
# the connection owner carried in X-DeerFlow-Owner-User-Id.
# Scope the check to that owner instead of bypassing it; a
# leaked internal token must not grant cross-user thread
# access. The header is honored only after ``auth`` proved
# the caller holds the internal token (mirrors
# get_trusted_internal_owner_user_id, which keys off the
# middleware-stamped ``request.state.user``).
header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip()
if header_owner:
allowed = await thread_store.check_access(
thread_id,
header_owner,
require_existing=require_existing,
)
if not allowed:
raise HTTPException(
status_code=404,
+25 -2
View File
@@ -5,10 +5,12 @@ from __future__ import annotations
import os
import secrets
from types import SimpleNamespace
from typing import Any
from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
INTERNAL_SYSTEM_ROLE = "internal"
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
def create_internal_auth_headers() -> dict[str, str]:
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
"""Return headers that authenticate trusted Gateway internal calls."""
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
if owner_user_id:
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
return headers
def is_valid_internal_auth_token(token: str | None) -> bool:
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
"""Return the owner override for a trusted internal request, if present.
The header is ignored for normal browser/API callers. It is only honored
after ``AuthMiddleware`` has validated the internal auth token and stamped
the synthetic internal user onto ``request.state.user``.
"""
user = getattr(getattr(request, "state", None), "user", None)
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
return None
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
if not owner_user_id:
return None
owner_user_id = owner_user_id.strip()
return owner_user_id or None
@@ -0,0 +1,670 @@
"""Browser-facing APIs for user-owned IM channel bindings."""
from __future__ import annotations
import asyncio
import logging
import secrets
from datetime import UTC, datetime, timedelta
from typing import Any
from fastapi import APIRouter, HTTPException, Request, Response
from pydantic import BaseModel, Field
from app.channels.runtime_config_store import (
ChannelRuntimeConfigStore,
apply_runtime_connection_config,
merge_runtime_channel_configs,
)
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
logger = logging.getLogger(__name__)
_STATE_TTL_SECONDS = 600
_MASKED_CREDENTIAL_VALUE = "********"
class ChannelCredentialFieldResponse(BaseModel):
name: str
label: str
type: str = "text"
required: bool = True
class ChannelProviderResponse(BaseModel):
provider: str
display_name: str
enabled: bool
configured: bool
connectable: bool
unavailable_reason: str | None = None
auth_mode: str
connection_status: str
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
credential_values: dict[str, str] = Field(default_factory=dict)
class ChannelProvidersResponse(BaseModel):
enabled: bool
providers: list[ChannelProviderResponse]
class ChannelConnectionResponse(BaseModel):
id: str
provider: str
status: str
external_account_id: str | None = None
external_account_name: str | None = None
workspace_id: str | None = None
workspace_name: str | None = None
scopes: list[str] = Field(default_factory=list)
metadata: dict[str, Any] = Field(default_factory=dict)
class ChannelConnectionsResponse(BaseModel):
connections: list[ChannelConnectionResponse]
class ChannelConnectResponse(BaseModel):
provider: str
mode: str
url: str | None = None
code: str
instruction: str
expires_in: int
class ChannelRuntimeConfigRequest(BaseModel):
values: dict[str, str] = Field(default_factory=dict)
_PROVIDER_META: dict[str, dict[str, str]] = {
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
}
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
"telegram": (
{"name": "bot_token", "label": "Bot token", "type": "password"},
{"name": "bot_username", "label": "Bot username", "type": "text"},
),
"slack": (
{"name": "bot_token", "label": "Bot token", "type": "password"},
{"name": "app_token", "label": "App token", "type": "password"},
),
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
"feishu": (
{"name": "app_id", "label": "App ID", "type": "text"},
{"name": "app_secret", "label": "App secret", "type": "password"},
),
"dingtalk": (
{"name": "client_id", "label": "Client ID", "type": "text"},
{"name": "client_secret", "label": "Client secret", "type": "password"},
),
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
"wecom": (
{"name": "bot_id", "label": "Bot ID", "type": "text"},
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
),
}
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
"telegram": ("bot_token",),
"slack": ("bot_token", "app_token"),
"discord": ("bot_token",),
"feishu": ("app_id", "app_secret"),
"dingtalk": ("client_id", "client_secret"),
"wechat": ("bot_token",),
"wecom": ("bot_id", "bot_secret"),
}
def _get_user_id(request: Request) -> str:
user = getattr(request.state, "user", None)
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
return str(user.id)
async def _require_admin_user(request: Request) -> None:
"""Require an admin caller for instance-wide channel runtime mutations.
Runtime credentials and the channel workers they start/stop are shared by
every user of the deployment, so only admins may change them (same model
as the MCP config API). Auth-disabled local mode uses a synthetic admin
user and is unaffected.
"""
user = getattr(request.state, "user", None)
if user is None:
from app.gateway.deps import get_current_user_from_request
user = await get_current_user_from_request(request)
if getattr(user, "system_role", None) != "admin":
raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.")
def _get_app_config():
from deerflow.config.app_config import get_app_config
return get_app_config()
async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
store = getattr(request.app.state, "channel_runtime_config_store", None)
if isinstance(store, ChannelRuntimeConfigStore):
return store
# Constructing the store reads its JSON file from disk; keep it off the
# event loop.
store = await asyncio.to_thread(ChannelRuntimeConfigStore)
request.app.state.channel_runtime_config_store = store
return store
async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
config = getattr(request.app.state, "channel_connections_config", None)
if not isinstance(config, ChannelConnectionsConfig):
config = _get_app_config().channel_connections
config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request))
request.app.state.channel_connections_config = config
return config
async def _get_channels_config(request: Request) -> dict[str, Any]:
state_config = getattr(request.app.state, "channels_config", None)
if isinstance(state_config, dict):
return state_config
result = await _load_channels_config(request, await _get_channel_connections_config(request))
request.app.state.channels_config = result
return result
async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
app_config = _get_app_config()
extra = app_config.model_extra or {}
channels_config = extra.get("channels")
result = dict(channels_config) if isinstance(channels_config, dict) else {}
merge_runtime_channel_configs(
result,
config,
store=await _get_runtime_config_store(request),
)
return result
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
repo = getattr(request.app.state, "channel_connection_repo", None)
if isinstance(repo, ChannelConnectionRepository):
return repo
sf = get_session_factory()
if sf is None:
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
repo = ChannelConnectionRepository(sf)
request.app.state.channel_connection_repo = repo
return repo
def _provider_config(config: ChannelConnectionsConfig, provider: str):
provider_config = getattr(config, provider, None)
if provider_config is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return provider_config
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
return False
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
def _runtime_unavailable_reason(provider: str) -> str:
meta = _PROVIDER_META.get(provider)
display_name = meta["display_name"] if meta else provider
return f"Enter the required {display_name} credentials to connect this channel."
def _runtime_not_running_reason(provider: str) -> str:
meta = _PROVIDER_META.get(provider)
display_name = meta["display_name"] if meta else provider
return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
def _runtime_channel_running(provider: str) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.debug("Unable to inspect channel service status", exc_info=True)
return None
service = get_channel_service()
if service is None:
return None
try:
status = service.get_status()
except Exception:
logger.debug("Unable to read channel service status", exc_info=True)
return None
if not status.get("service_running"):
return False
channel_status = status.get("channels", {}).get(provider)
if not isinstance(channel_status, dict):
return None
return bool(channel_status.get("running"))
async def _ensure_runtime_channel_ready_if_available(
provider: str,
channels_config: dict[str, Any],
) -> bool | None:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
return None
try:
from app.channels.service import get_channel_service
except Exception:
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
return None
service = get_channel_service()
if service is None:
return None
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
if ensure_channel_ready is None:
return None
try:
return await ensure_channel_ready(provider, runtime_config)
except Exception:
logger.exception("Failed to reconcile runtime channel readiness")
return False
def _provider_unavailable_reason(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
) -> str | None:
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
return None
if not provider_config.configured:
return _runtime_unavailable_reason(provider)
if not _runtime_channel_configured(provider, channels_config):
return _runtime_unavailable_reason(provider)
if _runtime_channel_running(provider) is False:
return _runtime_not_running_reason(provider)
return None
def _provider_status(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
) -> tuple[dict[str, bool], str | None]:
declared = config.provider_status(provider)
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
def _new_binding_code() -> str:
return secrets.token_urlsafe(16)
async def _create_state(
repo: ChannelConnectionRepository,
*,
owner_user_id: str,
provider: str,
) -> str:
state = _new_binding_code()
await repo.create_oauth_state(
owner_user_id=owner_user_id,
provider=provider,
state=state,
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
)
return state
def _connect_instruction(provider: str, code: str) -> str:
if provider == "telegram":
return f"Send /start {code} to the DeerFlow Telegram bot."
meta = _PROVIDER_META.get(provider)
if meta is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
if provider == "telegram":
provider_config = _provider_config(config, provider)
return f"https://t.me/{provider_config.bot_username}?start={code}"
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
return None
raise HTTPException(status_code=404, detail="Unknown channel provider")
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
value = connection.get("updated_at")
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
if isinstance(value, str) and value:
try:
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
pass
return datetime.min.replace(tzinfo=UTC)
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
by_provider: dict[str, dict[str, Any]] = {}
for item in connections:
existing = by_provider.get(item["provider"])
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
by_provider[item["provider"]] = item
return by_provider
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
fields = _CREDENTIAL_FIELDS.get(provider)
if fields is None:
raise HTTPException(status_code=404, detail="Unknown channel provider")
return [ChannelCredentialFieldResponse(**field) for field in fields]
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
runtime_config = channels_config.get(provider)
if not isinstance(runtime_config, dict):
return {}
values: dict[str, str] = {}
for field in _credential_fields(provider):
value = str(runtime_config.get(field.name) or "").strip()
if not value:
continue
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
return values
def _provider_response(
config: ChannelConnectionsConfig,
channels_config: dict[str, Any],
provider: str,
meta: dict[str, str],
connection: dict[str, Any] | None = None,
) -> ChannelProviderResponse:
from app.gateway.auth_disabled import is_auth_disabled
status, unavailable_reason = _provider_status(config, channels_config, provider)
if connection:
connection_status = connection["status"]
elif is_auth_disabled() and status["configured"] and unavailable_reason is None:
# Auth-disabled local mode routes every channel message to the default
# user, so a configured running channel needs no per-user binding.
connection_status = "connected"
else:
connection_status = "not_connected"
credential_values = _credential_values(provider, channels_config)
if provider == "telegram" and not credential_values.get("bot_username"):
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
if bot_username:
credential_values["bot_username"] = bot_username
return ChannelProviderResponse(
provider=provider,
display_name=meta["display_name"],
enabled=status["enabled"],
configured=status["configured"],
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
unavailable_reason=unavailable_reason,
auth_mode=meta["auth_mode"],
connection_status=connection_status,
credential_fields=_credential_fields(provider),
credential_values=credential_values,
)
def _required_runtime_values(
provider: str,
values: dict[str, str],
existing_config: dict[str, Any] | None = None,
) -> dict[str, str]:
fields = _credential_fields(provider)
cleaned: dict[str, str] = {}
missing: list[str] = []
existing_config = existing_config or {}
for field in fields:
raw_value = values.get(field.name, "")
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
existing_value = str(existing_config.get(field.name) or "").strip()
if existing_value:
cleaned[field.name] = existing_value
continue
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
if field.required and not value:
missing.append(field.label)
cleaned[field.name] = value
if missing:
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
return cleaned
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while configuring a runtime channel")
return None
service = get_channel_service()
if service is None:
return None
return await service.configure_channel(provider, runtime_config)
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while disconnecting a runtime channel")
return None
service = get_channel_service()
if service is None:
return None
runtime_config = channels_config.get(provider)
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
return await service.configure_channel(provider, runtime_config)
return await service.remove_channel(provider)
@router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
config = await _get_channel_connections_config(request)
channels_config = await _get_channels_config(request)
repo = None
if config.enabled:
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
owner_user_id = _get_user_id(request)
connections = await repo.list_connections(owner_user_id) if repo is not None else []
by_provider = _newest_connection_by_provider(connections)
enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]]
# Readiness reconciliation is independent per provider; run it
# concurrently so one slow channel restart does not serialize the
# whole /providers response.
await asyncio.gather(
*(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)),
)
providers: list[ChannelProviderResponse] = []
for provider in enabled_providers:
connection = by_provider.get(provider)
providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection))
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
@router.get("/connections", response_model=ChannelConnectionsResponse)
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
config = await _get_channel_connections_config(request)
if not config.enabled:
return ChannelConnectionsResponse(connections=[])
repo = _get_repository(request, config)
rows = await repo.list_connections(_get_user_id(request))
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
@router.delete("/connections/{connection_id}", status_code=204)
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
config = await _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
repo = _get_repository(request, config)
disconnected = await repo.disconnect_connection(
connection_id=connection_id,
owner_user_id=_get_user_id(request),
)
if not disconnected:
raise HTTPException(status_code=404, detail="Channel connection not found")
return Response(status_code=204)
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
await _require_admin_user(request)
config = await _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
owner_user_id = _get_user_id(request)
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
repo = None
if repo is not None:
for connection in await repo.list_connections(owner_user_id):
if connection["provider"] == provider and connection["status"] != "revoked":
await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id=owner_user_id,
)
store = await _get_runtime_config_store(request)
await asyncio.to_thread(store.set_provider_disconnected, provider)
channels_config = await _load_channels_config(request, config)
request.app.state.channels_config = channels_config
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
if stopped is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
config = await _get_channel_connections_config(request)
channels_config = await _get_channels_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
status, unavailable_reason = _provider_status(config, channels_config, provider)
if not status["enabled"]:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
if unavailable_reason:
raise HTTPException(status_code=400, detail=unavailable_reason)
if not status["configured"]:
raise HTTPException(status_code=400, detail="Channel provider is not configured")
repo = _get_repository(request, config)
code = await _create_state(
repo,
owner_user_id=_get_user_id(request),
provider=provider,
)
return ChannelConnectResponse(
provider=provider,
mode=_PROVIDER_META[provider]["auth_mode"],
url=_connect_url(config, provider, code),
code=code,
instruction=_connect_instruction(provider, code),
expires_in=_STATE_TTL_SECONDS,
)
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def configure_channel_provider_runtime(
provider: str,
body: ChannelRuntimeConfigRequest,
request: Request,
) -> ChannelProviderResponse:
await _require_admin_user(request)
config = await _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
channels_config = await _get_channels_config(request)
existing = channels_config.get(provider)
runtime_config = dict(existing) if isinstance(existing, dict) else {}
values = _required_runtime_values(provider, body.values, runtime_config)
runtime_config["enabled"] = True
for key in _RUNTIME_REQUIREMENTS[provider]:
runtime_config[key] = values[key]
if provider == "telegram":
# The deep-link username is persisted with the runtime channel config
# (set_provider_config below) and applied to future requests via
# apply_runtime_connection_config; never mutate the config instance
# cached by get_app_config().
runtime_config["bot_username"] = values["bot_username"]
channels_config[provider] = runtime_config
request.app.state.channels_config = channels_config
started = await _restart_runtime_channel_if_available(provider, runtime_config)
if started is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
store = await _get_runtime_config_store(request)
await asyncio.to_thread(store.set_provider_config, provider, runtime_config)
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
+11 -1
View File
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = now_iso()
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id)
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is None and thread_owner_user_id:
unscoped_record = await thread_store.get(thread_id, user_id=None)
if unscoped_record is not None:
if unscoped_record.get("user_id") != thread_owner_user_id:
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
await thread_store.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
**thread_owner_kwargs,
metadata=body.metadata,
)
except Exception:
+84 -61
View File
@@ -12,6 +12,7 @@ import json
import logging
import re
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any
from fastapi import HTTPException, Request
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config
from deerflow.runtime import (
@@ -35,6 +36,7 @@ from deerflow.runtime import (
run_agent,
)
from deerflow.runtime.runs.naming import resolve_root_run_name
from deerflow.runtime.user_context import reset_current_user, set_current_user
logger = logging.getLogger(__name__)
@@ -315,6 +317,7 @@ async def start_run(
detail=f"Model {model_name!r} is not in the configured model allowlist",
)
owner_user_id = get_trusted_internal_owner_user_id(request)
# Stateless run endpoints carry thread_id in the request *body*, so the
# @require_permission(owner_check=True) decorator -- which resolves ownership
# from the path param -- cannot protect them. Enforce thread ownership here,
@@ -323,79 +326,99 @@ async def start_run(
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
# via check_access; only a thread already owned by another user is rejected
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
# channel runs act on behalf of IM users they do not own (see
# inject_authenticated_user_context), so the internal system role is exempt.
# channel runs act on behalf of the connection owner carried in
# X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
# bypassing the check -- a leaked internal token must not grant cross-user
# thread access.
user = getattr(request.state, "user", None)
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
if user is not None:
allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
# Channel workers may also act for the connection owner named in
# the trusted header (e.g. claiming a legacy default-owned channel
# thread for its real owner).
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
if not allowed:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
try:
record = await run_mgr.create_or_reject(
thread_id,
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
try:
record = await run_mgr.create_or_reject(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
user_id=owner_user_id,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None and owner_user_id:
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
if unscoped_existing is not None:
if unscoped_existing.get("user_id") != owner_user_id:
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
stream_modes = normalize_stream_modes(body.stream_mode)
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
task = asyncio.create_task(
run_agent(
bridge,
run_mgr,
record,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
stream_modes=stream_modes,
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
stream_modes = normalize_stream_modes(body.stream_mode)
task = asyncio.create_task(
run_agent(
bridge,
run_mgr,
record,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
stream_modes=stream_modes,
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
)
)
)
record.task = task
record.task = task
# Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_store.update_display_name
# after the run completes.
# Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_store.update_display_name
# after the run completes.
return record
return record
finally:
if owner_context_token is not None:
reset_current_user(owner_context_token)
async def sse_consumer(