feat(im): Add user-owned IM channel connections (#3487)

* Add user-owned IM channel connections

* Fix dev startup and channel connect popup

* Use async channel connect flow

* Harden dev service daemon startup

* Support local IM channel connections

* Align IM connections with local channels

* Fix safe user id digest algorithm

* Address Copilot IM channel feedback

* Address IM channel review comments

* Support all integrated IM channel connections

* Format additional channel connection tests

* Keep unavailable channel connect buttons clickable

* Fix IM channel provider icons

* Add runtime setup for enabled IM channels

* Guard global shortcut key handling

* Keep configured IM channels editable

* Avoid password autofill for channel secrets

* Make channel threads visible to connection owners

* Persist IM runtime config locally

* Allow disconnecting runtime IM channels

* Route no-auth channel sessions to local user

* Use default user for auth-disabled local mode

* Show IM channel source on threads

* Prefill IM channel runtime config

* Reflect IM channel runtime health

* Ignore Feishu message read events

* Ignore Feishu non-content message events

* Let setup wizard enable IM channels

* Fix frontend formatting after merge

* Stabilize backend tests without local config

* Isolate channel runtime config tests

* Address channel connection review comments

* Use sha256 user buckets with legacy migration

* Ensure runtime IM channels are ready after restart

* Persist disconnected IM channel state

* Address channel connection review comments

* Address channel connection review findings

Frontend connect flow:
- Open the runtime-config dialog only when a provider still needs
  credentials; configured providers go straight to the connect flow, so
  the binding-code/deep-link path is reachable from the UI again.
- After saving credentials, continue into the connect flow when a user
  binding is still required (multi-user mode) instead of stopping at a
  "Connected" toast.
- Extract shared provider-state helpers to core/channels/provider-state
  and add unit + e2e coverage for the direct-connect and
  configure-then-connect paths.

Provider status semantics:
- Report connection_status from the user's newest connection row;
  with no binding it is not_connected, except in auth-disabled local
  mode where a configured running channel is effectively connected.

Concurrency and event-loop correctness:
- Offload ChannelRuntimeConfigStore construction and writes, channel
  service construction, and Slack connection replies to threads; add a
  tests/blocking_io/ anchor for the runtime-config handlers.
- Consume binding codes with a conditional UPDATE so a code can only be
  used once under concurrent workers; retry upsert_connection as an
  update when a concurrent insert wins the unique constraint.
- Serialize ensure_channel_ready per channel so concurrent provider
  polls cannot double-start a channel worker.

Config and migration hardening:
- Stop mutating the get_app_config()-cached Telegram provider config;
  the runtime store now owns the UI-entered bot username.
- Register channel_connections in STARTUP_ONLY_FIELDS with the
  standardized startup-only Field description.
- Match the legacy unsafe-id bucket by recomputing its exact SHA-1 name
  so another user's same-prefix bucket can never be migrated.
- Remove the unused Telegram process_webhook_update path and document
  src/core/channels in the frontend docs.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Address PR review comments on authz scoping and channel runtime

Security (review feedback from ShenAC-SAC):
- Scope internal-token callers to the connection owner carried in
  X-DeerFlow-Owner-User-Id instead of bypassing owner checks outright,
  in both require_permission(owner_check=True) and the stateless run
  endpoints. Internal callers keep access to their own and
  shared/legacy threads, and may claim a default-owned channel thread
  for its real owner, but a leaked internal token no longer grants
  cross-user thread access.
- Require admin privileges for POST/DELETE /api/channels/{provider}/
  runtime-config: runtime credentials and channel workers are
  instance-wide shared state (same model as the MCP config API).
  Read-only provider listing stays available to all users.

Performance (review feedback from willem-bd):
- Skip the redundant thread channel-metadata PATCH after the first
  successful backfill per thread.
- Reuse the per-connection Slack WebClient until its token changes
  instead of constructing one per outbound message.
- Reconcile channel readiness for all providers concurrently in
  GET /api/channels/providers.

Also resolve the code-quality unused-import flag in the blocking-io
anchor by pre-importing the channel service via importlib.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Fix prettier formatting in provider-state test

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

* Reconcile UI runtime channel config with config reload on restart

Main now reloads a channel's config.yaml entry on restart_channel()
(#3514, issue #3497). Adapt the user-owned connection flow to coexist:

- configure_channel() restarts with reload_config=False — the caller
  just supplied the authoritative config (browser-entered credentials
  that are never written to config.yaml), so a file reload must not
  clobber it with the stale on-disk entry.
- _load_channel_config() re-applies the UI runtime-store overlay used
  at startup, so an operator-triggered restart keeps browser-entered
  credentials for channels without a config.yaml entry and does not
  resurrect a channel disconnected from the UI.
- Offload the reload's disk IO (config.yaml + runtime store) with
  asyncio.to_thread, matching the blocking-IO policy on this branch.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

---------

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
DanielWalnut
2026-06-12 15:24:58 +08:00
committed by GitHub
parent b8f5ed360f
commit aa015462a7
96 changed files with 8585 additions and 277 deletions
+11
View File
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
)
def extract_connect_code(text: str) -> str | None:
"""Extract the one-time channel binding code from a connect command."""
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command."""
if not text.startswith("/"):
@@ -0,0 +1,44 @@
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
from __future__ import annotations
from typing import Any
from app.channels.message_bus import InboundMessage
async def attach_connection_identity(
inbound: InboundMessage,
*,
repo: Any,
provider: str,
workspace_id: str | None,
fallback_without_workspace: bool = False,
) -> InboundMessage:
"""Attach connection metadata to an inbound message when a persisted binding exists."""
if repo is None:
return inbound
workspace_candidates: list[str | None] = []
if workspace_id:
workspace_candidates.append(workspace_id)
if fallback_without_workspace:
workspace_candidates.append(None)
if not workspace_candidates:
return inbound
for candidate in workspace_candidates:
connection = await repo.find_connection_by_external_identity(
provider=provider,
external_account_id=inbound.user_id,
workspace_id=candidate,
)
if connection is None:
continue
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return inbound
+105 -1
View File
@@ -14,7 +14,8 @@ from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
self._incoming_messages: dict[str, Any] = {}
self._incoming_messages_lock = threading.Lock()
self._card_repliers: dict[str, Any] = {}
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
text[:100],
)
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
conversation_type=conversation_type,
sender_staff_id=sender_staff_id,
sender_nick=sender_nick,
conversation_id=conversation_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
return
if _is_dingtalk_command(text):
msg_type = InboundMessageType.COMMAND
else:
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
return ""
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
inbound = await self._attach_connection_identity(inbound)
# Running reply must finish before publish_inbound so AI card tracks are
# registered before the manager emits streaming outbounds.
await self._send_running_reply(chat_id, inbound)
await self.bus.publish_inbound(inbound)
@staticmethod
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
return conversation_id
return None
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
conversation_id = str(inbound.metadata.get("conversation_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="dingtalk",
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(
self,
*,
conversation_type: str,
sender_staff_id: str,
sender_nick: str,
conversation_id: str,
code: str,
) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
if state is None:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection code is invalid or expired.",
)
return True
if not sender_staff_id:
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connection could not be completed from this message.",
)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="dingtalk",
external_account_id=sender_staff_id,
external_account_name=sender_nick or None,
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
metadata={
"conversation_type": conversation_type,
"conversation_id": conversation_id,
},
status="connected",
)
await self._send_connection_reply(
conversation_type,
sender_staff_id,
conversation_id,
"DingTalk connected to DeerFlow.",
)
return True
async def _send_connection_reply(
self,
conversation_type: str,
sender_staff_id: str,
conversation_id: str,
text: str,
) -> None:
robot_code = self._client_id
if conversation_type == _CONVERSATION_TYPE_GROUP:
if conversation_id:
await self._send_text_message_to_group(robot_code, conversation_id, text)
return
if sender_staff_id:
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
+64 -2
View File
@@ -10,8 +10,9 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -287,6 +289,10 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
# --- Determine thread/channel routing and typing target ---
thread_id = None
chat_id = None
@@ -315,6 +321,7 @@ class DiscordChannel(Channel):
},
)
inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
self._publish(inbound)
# Start typing indicator in the thread
if typing_target:
@@ -422,6 +429,7 @@ class DiscordChannel(Channel):
},
)
inbound.topic_id = thread_id
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
# Start typing indicator in the correct target (thread or channel)
if typing_target:
@@ -436,6 +444,60 @@ class DiscordChannel(Channel):
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="discord",
workspace_id=guild_id,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
if state is None:
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
return True
guild = getattr(message, "guild", None)
channel = getattr(message, "channel", None)
author = getattr(message, "author", None)
user_id = str(getattr(author, "id", "") or "")
if not user_id:
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
return True
guild_id = str(getattr(guild, "id", "") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="discord",
external_account_id=user_id,
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
workspace_id=guild_id,
workspace_name=getattr(guild, "name", None) if guild is not None else None,
metadata={
"guild_id": guild_id,
"channel_id": str(getattr(channel, "id", "") or ""),
},
status="connected",
)
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
return True
@staticmethod
async def _send_connection_reply(message, text: str) -> None:
channel = getattr(message, "channel", None)
send = getattr(channel, "send", None)
if send is None:
return
try:
await send(text)
except Exception:
logger.exception("[Discord] failed to send connection reply")
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
+78 -2
View File
@@ -11,7 +11,8 @@ import time
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
self._CreateImageRequestBody = None
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
self._connection_repo = config.get("connection_repo")
@staticmethod
def _non_empty_str(value: Any) -> str | None:
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
def supports_streaming(self) -> bool:
return True
@property
def is_running(self) -> bool:
if not self._running:
return False
return self._thread is not None and self._thread.is_alive()
def _build_event_handler(self, lark):
return (
lark.EventDispatcherHandler.builder("", "")
.register_p2_im_message_receive_v1(self._on_message)
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
.build()
)
async def start(self) -> None:
if self._running:
return
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
# thread's uvloop.
_ws_client_mod.loop = loop
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
event_handler = self._build_event_handler(lark)
ws_client = lark.ws.Client(
app_id=app_id,
app_secret=app_secret,
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
except Exception:
if self._running:
logger.exception("Feishu WebSocket error")
self._running = False
def _on_ignored_message_event(self, event) -> None:
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
async def stop(self) -> None:
self._running = False
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
"""Kick off Feishu side effects without delaying inbound dispatch."""
inbound = await self._attach_connection_identity(inbound)
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
self._ensure_running_card_started(msg_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="feishu",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
if state is None:
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
return True
if not user_id or not chat_id:
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="feishu",
external_account_id=user_id,
workspace_id=chat_id,
metadata={
"chat_id": chat_id,
"message_id": message_id,
},
status="connected",
)
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
return True
def _on_message(self, event) -> None:
"""Called by lark-oapi when a message is received (runs in lark thread)."""
try:
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
logger.info("[Feishu] empty text, ignoring message")
return
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
message_id=msg_id,
chat_id=chat_id,
user_id=sender_id,
code=connect_code,
),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
else:
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
return
# Only treat known slash commands as commands; absolute paths and
# other slash-prefixed text should be handled as normal chat.
if _is_feishu_command(text):
+152 -30
View File
@@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
return metadata
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
channel_source: dict[str, Any] = {
"type": "im_channel",
"provider": msg.channel_name,
"chat_id": msg.chat_id,
}
if msg.topic_id:
channel_source["topic_id"] = msg.topic_id
if msg.thread_ts:
channel_source["thread_ts"] = msg.thread_ts
if msg.connection_id:
channel_source["connection_id"] = msg.connection_id
return {"channel_source": channel_source}
def _extract_text_content(content: Any) -> str:
"""Extract text from a streaming payload content field."""
if isinstance(content, str):
@@ -440,6 +456,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
return message
def _auth_disabled_owner_user_id() -> str | None:
try:
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
except Exception:
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
return None
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
return _auth_disabled_owner_user_id() or msg.owner_user_id
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
msg.owner_user_id = owner_user_id
return msg
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
owner_user_id = _effective_owner_user_id(msg)
if not owner_user_id:
return None
return create_internal_auth_headers(owner_user_id=owner_user_id)
def _safe_user_id_for_run(raw_user_id: str) -> str:
from deerflow.config.paths import get_paths
try:
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
except Exception:
logger.exception("Failed to prepare channel run user directory")
return make_safe_user_id(raw_user_id)
def _resolve_slash_skill_command(
text: str,
available_skills: set[str] | None = None,
@@ -670,6 +723,7 @@ class ChannelManager:
assistant_id: str = DEFAULT_ASSISTANT_ID,
default_session: dict[str, Any] | None = None,
channel_sessions: dict[str, Any] | None = None,
connection_repo: Any | None = None,
) -> None:
self.bus = bus
self.store = store
@@ -679,7 +733,9 @@ class ChannelManager:
self._assistant_id = assistant_id
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._connection_repo = connection_repo
self._client = None # lazy init — langgraph_sdk async client
self._channel_metadata_synced: set[str] = set()
self._skill_storage: SkillStorage | None = None
self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None
@@ -728,12 +784,17 @@ class ChannelManager:
configurable["checkpoint_ns"] = ""
configurable["thread_id"] = thread_id
# ``user_id`` drives user-scoped filesystem buckets that only accept
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
# under ``channel_user_id`` for platform-facing lookups.
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
# For browser-connected IM channels, prefer the DeerFlow account that
# owns the connection. Preserve the raw platform user under
# ``channel_user_id`` for platform-facing lookups and audits.
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
owner_user_id = _effective_owner_user_id(msg)
if owner_user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
elif msg.user_id:
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
if msg.user_id:
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
run_context_identity["channel_user_id"] = msg.user_id
run_context = _merge_dicts(
@@ -845,6 +906,7 @@ class ChannelManager:
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
async def _handle_message(self, msg: InboundMessage) -> None:
msg = _apply_effective_owner(msg)
async with self._semaphore:
try:
if msg.msg_type == InboundMessageType.COMMAND:
@@ -877,10 +939,27 @@ class ChannelManager:
# -- chat handling -----------------------------------------------------
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
thread = await client.threads.create()
thread_id = thread["thread_id"]
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
if msg.connection_id and self._connection_repo is not None:
return await self._connection_repo.get_thread_id(
msg.connection_id,
msg.chat_id,
msg.topic_id,
)
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
await self._connection_repo.set_thread_id(
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
provider=msg.channel_name,
external_conversation_id=msg.chat_id,
external_topic_id=msg.topic_id,
thread_id=thread_id,
)
return
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
@@ -888,18 +967,49 @@ class ChannelManager:
topic_id=msg.topic_id,
user_id=msg.user_id,
)
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping."""
metadata = _thread_channel_metadata(msg)
owner_headers = _owner_headers(msg)
if owner_headers:
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
else:
thread = await client.threads.create(metadata=metadata)
thread_id = thread["thread_id"]
await self._store_thread_id(msg, thread_id)
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
return thread_id
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
"""Best-effort source metadata backfill for existing IM-created threads."""
# The metadata (provider/chat/topic) is constant for a thread, so one
# successful backfill per manager lifetime is enough — skip the
# redundant PATCH on every subsequent inbound message.
if thread_id in self._channel_metadata_synced:
return
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
if owner_headers := _owner_headers(msg):
update_kwargs["headers"] = owner_headers
try:
await client.threads.update(thread_id, **update_kwargs)
except Exception:
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
return
if len(self._channel_metadata_synced) > 4096:
self._channel_metadata_synced.clear()
self._channel_metadata_synced.add(thread_id)
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
client = self._get_client()
# Look up existing DeerFlow thread.
# topic_id may be None (e.g. Telegram private chats) — the store
# handles this by using the "channel:chat_id" key without a topic suffix.
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
thread_id = await self._lookup_thread_id(msg)
if thread_id:
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
await self._update_thread_channel_metadata(client, msg, thread_id)
# No existing thread found — create a new one
if thread_id is None:
@@ -940,14 +1050,19 @@ class ChannelManager:
return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
run_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
run_kwargs["headers"] = owner_headers
try:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [human_message]},
config=run_config,
context=run_context,
multitask_strategy="reject",
**run_kwargs,
)
except Exception as exc:
if _is_thread_busy_error(exc):
@@ -984,6 +1099,8 @@ class ChannelManager:
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
@@ -1008,16 +1125,21 @@ class ChannelManager:
last_published_text = ""
last_publish_at = 0.0
stream_error: BaseException | None = None
stream_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"stream_mode": ["messages-tuple", "values"],
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
stream_kwargs["headers"] = owner_headers
try:
async for chunk in client.runs.stream(
thread_id,
assistant_id,
input={"messages": [human_message]},
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
multitask_strategy="reject",
**stream_kwargs,
):
event = getattr(chunk, "event", "")
data = getattr(chunk, "data", None)
@@ -1047,6 +1169,8 @@ class ChannelManager:
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata),
)
)
@@ -1093,6 +1217,8 @@ class ChannelManager:
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
)
)
@@ -1124,18 +1250,10 @@ class ChannelManager:
if reply is None and command == "new":
# Create a new thread through Gateway
client = self._get_client()
thread = await client.threads.create()
new_thread_id = thread["thread_id"]
self.store.set_thread_id(
msg.channel_name,
msg.chat_id,
new_thread_id,
topic_id=msg.topic_id,
user_id=msg.user_id,
)
await self._create_thread(client, msg)
reply = "New conversation started."
elif reply is None and command == "status":
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
thread_id = await self._lookup_thread_id(msg)
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
elif reply is None and command == "models":
reply = await self._fetch_gateway("/api/models", "models")
@@ -1174,9 +1292,11 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=await self._lookup_thread_id(msg) or "",
text=reply,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
@@ -1212,9 +1332,11 @@ class ChannelManager:
outbound = OutboundMessage(
channel_name=msg.channel_name,
chat_id=msg.chat_id,
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
thread_id=await self._lookup_thread_id(msg) or "",
text=error_text,
thread_ts=msg.thread_ts,
connection_id=msg.connection_id,
owner_user_id=msg.owner_user_id,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
+14
View File
@@ -44,6 +44,12 @@ class InboundMessage:
Messages sharing the same ``topic_id`` within a ``chat_id`` will
reuse the same DeerFlow thread. When ``None``, each message
creates a new thread (one-shot Q&A).
connection_id: Optional DeerFlow channel connection id. When present,
conversation mapping is scoped by the connection instead of the
legacy global ``channel_name:chat_id[:topic_id]`` key.
owner_user_id: DeerFlow user id that owns the channel connection.
Platform user ids stay in ``user_id``.
workspace_id: Optional external workspace/guild/team id.
files: Optional list of file attachments (platform-specific dicts).
metadata: Arbitrary extra data from the channel.
created_at: Unix timestamp when the message was created.
@@ -56,6 +62,9 @@ class InboundMessage:
msg_type: InboundMessageType = InboundMessageType.CHAT
thread_ts: str | None = None
topic_id: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
workspace_id: str | None = None
files: list[dict[str, Any]] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@@ -95,6 +104,9 @@ class OutboundMessage:
is_final: Whether this is the final message in the response stream.
thread_ts: Optional platform thread identifier for threaded replies.
metadata: Arbitrary extra data.
connection_id: Optional DeerFlow channel connection id used for
connection-specific outbound credentials.
owner_user_id: DeerFlow user id that owns the channel connection.
created_at: Unix timestamp.
"""
@@ -106,6 +118,8 @@ class OutboundMessage:
attachments: list[ResolvedAttachment] = field(default_factory=list)
is_final: bool = True
thread_ts: str | None = None
connection_id: str | None = None
owner_user_id: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
created_at: float = field(default_factory=time.time)
@@ -0,0 +1,154 @@
"""Local persistence for runtime IM channel configuration."""
from __future__ import annotations
import json
import logging
import tempfile
import threading
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
class ChannelRuntimeConfigStore:
"""JSON-backed store for channel credentials entered from the UI.
This intentionally mirrors ``ChannelStore``: local/private deployments get
durable runtime configuration without needing a public callback URL or a
config.yaml edit.
"""
def __init__(self, path: str | Path | None = None) -> None:
if path is None:
from deerflow.config.paths import get_paths
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
self._path = Path(path)
self._path.parent.mkdir(parents=True, exist_ok=True)
self._data: dict[str, dict[str, Any]] = self._load()
self._lock = threading.Lock()
def _load(self) -> dict[str, dict[str, Any]]:
if self._path.exists():
try:
raw = json.loads(self._path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
return {}
if isinstance(raw, dict):
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
return {}
def _save(self) -> None:
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=self._path.parent,
suffix=".tmp",
delete=False,
)
try:
json.dump(self._data, fd, indent=2, ensure_ascii=False)
fd.close()
Path(fd.name).replace(self._path)
try:
self._path.chmod(0o600)
except OSError:
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def load_all(self) -> dict[str, dict[str, Any]]:
with self._lock:
return {name: dict(config) for name, config in self._data.items()}
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
with self._lock:
config = self._data.get(provider)
return dict(config) if isinstance(config, dict) else None
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
with self._lock:
self._data[provider] = dict(config)
self._save()
def set_provider_disconnected(self, provider: str) -> None:
with self._lock:
self._data[provider] = {
"enabled": False,
RUNTIME_CHANNEL_DISABLED_FLAG: True,
}
self._save()
def remove_provider_config(self, provider: str) -> bool:
with self._lock:
if provider not in self._data:
return False
del self._data[provider]
self._save()
return True
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
provider_config = getattr(channel_connections_config, provider, None)
return bool(getattr(provider_config, "enabled", False))
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
def merge_runtime_channel_configs(
channels_config: dict[str, Any],
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> None:
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return
runtime_store = store or ChannelRuntimeConfigStore()
for provider, runtime_config in runtime_store.load_all().items():
if not _provider_enabled(channel_connections_config, provider):
continue
if _runtime_channel_disconnected(runtime_config):
channels_config.pop(provider, None)
continue
existing = channels_config.get(provider)
merged = dict(runtime_config)
if isinstance(existing, dict):
merged.update(existing)
channels_config[provider] = merged
def apply_runtime_connection_config(
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> Any:
"""Apply persisted connection metadata that lives outside ``channels``.
Telegram uses a bot username for deep links; UI-entered values are stored
with the runtime channel config so local restarts keep the provider
configured.
"""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return channel_connections_config
runtime_store = store or ChannelRuntimeConfigStore()
telegram_runtime_config = runtime_store.get_provider_config("telegram")
bot_username = ""
if isinstance(telegram_runtime_config, dict):
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
return channel_connections_config
config = channel_connections_config.model_copy(deep=True)
config.telegram.bot_username = bot_username
return config
+145 -26
View File
@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
import os
from typing import TYPE_CHECKING, Any
@@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.runtime_config_store import merge_runtime_channel_configs
from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
@@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
value = config.pop(config_key, None)
if isinstance(value, str) and value.strip():
@@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
return default
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
connection_config = getattr(app_config, "channel_connections", None)
merge_runtime_channel_configs(channels_config, connection_config)
def _make_connection_repo(app_config: AppConfig):
connection_config = getattr(app_config, "channel_connections", None)
if connection_config is None or not getattr(connection_config, "enabled", False):
return None
try:
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
except Exception:
logger.exception("Failed to import channel connection repository")
return None
session_factory = get_session_factory()
if session_factory is None:
logger.warning("Channel connections are enabled but database persistence is not available")
return None
return ChannelConnectionRepository(session_factory)
class ChannelService:
"""Manages the lifecycle of all configured IM channels.
@@ -59,9 +90,10 @@ class ChannelService:
instantiates enabled channels, and starts the ChannelManager dispatcher.
"""
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
self.bus = MessageBus()
self.store = ChannelStore()
self._connection_repo = connection_repo
config = dict(channels_config or {})
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
@@ -74,10 +106,12 @@ class ChannelService:
gateway_url=gateway_url,
default_session=default_session if isinstance(default_session, dict) else None,
channel_sessions=channel_sessions,
connection_repo=connection_repo,
)
self._channels: dict[str, Any] = {} # name -> Channel instance
self._config = config
self._running = False
self._readiness_locks: dict[str, asyncio.Lock] = {}
@classmethod
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
@@ -90,8 +124,9 @@ class ChannelService:
# extra fields are allowed by AppConfig (extra="allow")
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
channels_config = dict(extra["channels"] or {})
_merge_channel_connection_runtime_config(channels_config, app_config)
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
async def start(self) -> None:
"""Start the manager and all enabled channels."""
@@ -99,36 +134,83 @@ class ChannelService:
return
await self.manager.start()
self._running = True
ready_status = await self.ensure_ready_channels(attempts=2)
ready_count = sum(1 for ready in ready_status.values() if ready)
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
"""Start or restart enabled configured channels that are not ready."""
ready_status: dict[str, bool] = {}
for name, channel_config in self._config.items():
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
if _channel_has_credentials(name, channel_config):
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
)
else:
logger.info("Channel %s is disabled, skipping", name)
logger.info("A configured channel is disabled, skipping")
continue
await self._start_channel(name, channel_config)
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
return ready_status
self._running = True
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
async def ensure_channel_ready(
self,
name: str,
config: dict[str, Any] | None = None,
*,
attempts: int = 1,
) -> bool:
"""Ensure a single enabled channel is running using its current config."""
if not self._running:
logger.warning("ChannelService is not running; cannot ensure channel readiness")
return False
if config is not None:
self._config[name] = dict(config)
# Serialize per channel: readiness is polled from request handlers, so
# concurrent calls must not stop/start the same channel worker twice.
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
async with lock:
channel_config = self._config.get(name)
if not channel_config or not isinstance(channel_config, dict):
logger.warning("No config for requested channel")
return False
if not channel_config.get("enabled", False):
return False
channel = self._channels.get(name)
if channel is not None and channel.is_running:
return True
if channel is not None:
try:
await channel.stop()
except Exception:
logger.exception("Error stopping non-running channel before readiness retry")
self._channels.pop(name, None)
max_attempts = max(1, attempts)
for attempt in range(max_attempts):
if attempt > 0:
logger.info("Retrying channel startup after readiness check")
if await self._start_channel(name, channel_config):
return True
return False
async def stop(self) -> None:
"""Stop all channels and the manager."""
for name, channel in list(self._channels.items()):
try:
await channel.stop()
logger.info("Channel %s stopped", name)
logger.info("Channel stopped")
except Exception:
logger.exception("Error stopping channel %s", name)
logger.exception("Error stopping channel")
self._channels.clear()
await self.manager.stop()
@@ -140,6 +222,9 @@ class ChannelService:
Uses ``get_app_config()`` which detects file changes via mtime,
so edits to ``config.yaml`` are picked up without a process restart.
The UI runtime-config overlay applied at startup is re-applied here
so a file-driven reload neither drops credentials entered from the
browser nor resurrects a channel disconnected from it.
Falls back to the cached ``self._config`` when config loading fails.
"""
try:
@@ -147,7 +232,8 @@ class ChannelService:
app_config = get_app_config()
extra = app_config.model_extra or {}
channels_config = extra.get("channels", {})
channels_config = dict(extra.get("channels") or {})
_merge_channel_connection_runtime_config(channels_config, app_config)
channel_config = channels_config.get(name)
if isinstance(channel_config, dict):
# Update the cached config so get_status() stays consistent.
@@ -157,18 +243,23 @@ class ChannelService:
logger.exception("Failed to reload config for channel %s, using cached version", name)
return self._config.get(name)
async def restart_channel(self, name: str) -> bool:
async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
"""Restart a specific channel. Returns True if successful."""
if name in self._channels:
try:
await self._channels[name].stop()
except Exception:
logger.exception("Error stopping channel %s for restart", name)
logger.exception("Error stopping channel for restart")
del self._channels[name]
config = self._load_channel_config(name)
if reload_config:
# Reading config.yaml and the runtime store is disk IO; keep it
# off the event loop.
config = await asyncio.to_thread(self._load_channel_config, name)
else:
config = self._config.get(name)
if not config or not isinstance(config, dict):
logger.warning("No config for channel %s", name)
logger.warning("No config for requested channel")
return False
if not config.get("enabled", False):
@@ -177,11 +268,35 @@ class ChannelService:
return await self._start_channel(name, config)
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Apply runtime config for a channel and restart it if the service is running."""
self._config[name] = dict(config)
if not self._running:
return True
# The caller just supplied the authoritative config (e.g. credentials
# entered in the browser that are never written to config.yaml) — a
# file reload here would clobber it with the stale on-disk entry.
return await self.restart_channel(name, reload_config=False)
async def remove_channel(self, name: str) -> bool:
"""Remove runtime config for a channel and stop it if currently running."""
self._config.pop(name, None)
channel = self._channels.pop(name, None)
if channel is None:
return True
try:
await channel.stop()
logger.info("Channel stopped and removed")
return True
except Exception:
logger.exception("Error stopping channel for removal")
return False
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Instantiate and start a single channel."""
import_path = _CHANNEL_REGISTRY.get(name)
if not import_path:
logger.warning("Unknown channel type: %s", name)
logger.warning("Unknown channel type")
return False
try:
@@ -189,24 +304,26 @@ class ChannelService:
channel_cls = resolve_class(import_path, base_class=None)
except Exception:
logger.exception("Failed to import channel class for %s", name)
logger.exception("Failed to import channel class")
return False
try:
config = dict(config)
config["channel_store"] = self.store
if self._connection_repo is not None:
config["connection_repo"] = self._connection_repo
channel = channel_cls(bus=self.bus, config=config)
self._channels[name] = channel
await channel.start()
if not channel.is_running:
self._channels.pop(name, None)
logger.error("Channel %s did not enter a running state after start()", name)
logger.error("Channel did not enter a running state after start()")
return False
logger.info("Channel %s started", name)
logger.info("Channel started")
return True
except Exception:
self._channels.pop(name, None)
logger.exception("Failed to start channel %s", name)
logger.exception("Failed to start channel")
return False
def get_status(self) -> dict[str, Any]:
@@ -245,7 +362,9 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config(app_config)
# from_app_config reads the JSON channel store and runtime config files;
# keep that disk IO off the event loop.
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
await _channel_service.start()
return _channel_service
+148 -26
View File
@@ -9,7 +9,8 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -64,6 +65,9 @@ class SlackChannel(Channel):
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._connection_repo = config.get("connection_repo")
self._web_client_factory = config.get("web_client_factory")
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
configured_bot_user_id = config.get("bot_user_id")
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
@@ -80,26 +84,28 @@ class SlackChannel(Channel):
return
self._SocketModeResponse = SocketModeResponse
if self._web_client_factory is None:
self._web_client_factory = WebClient
bot_token = self.config.get("bot_token", "")
app_token = self.config.get("app_token", "")
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
if not bot_token:
logger.error("Slack HTTP Events mode requires bot_token")
return
await self._initialize_operator_web_client(str(bot_token))
self._loop = asyncio.get_event_loop()
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
logger.info("Slack channel started in HTTP Events mode")
return
if not bot_token or not app_token:
logger.error("Slack channel requires bot_token and app_token")
return
self._web_client = WebClient(token=bot_token)
if self._bot_user_id is None:
try:
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
await self._initialize_operator_web_client(str(bot_token))
self._socket_client = SocketModeClient(
app_token=app_token,
web_client=self._web_client,
@@ -124,7 +130,8 @@ class SlackChannel(Channel):
logger.info("Slack channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._web_client:
web_client = await self._get_web_client_for_message(msg)
if not web_client:
return
kwargs: dict[str, Any] = {
@@ -137,11 +144,12 @@ class SlackChannel(Channel):
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
# Add a completion reaction to the thread root
if msg.thread_ts:
await asyncio.to_thread(
self._add_reaction,
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"white_check_mark",
@@ -165,7 +173,8 @@ class SlackChannel(Channel):
if msg.thread_ts:
try:
await asyncio.to_thread(
self._add_reaction,
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"x",
@@ -177,7 +186,8 @@ class SlackChannel(Channel):
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._web_client:
web_client = await self._get_web_client_for_message(msg)
if not web_client:
return False
try:
@@ -190,7 +200,7 @@ class SlackChannel(Channel):
if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
return True
except Exception:
@@ -199,12 +209,45 @@ class SlackChannel(Channel):
# -- internal ----------------------------------------------------------
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
async def _initialize_operator_web_client(self, bot_token: str) -> None:
self._web_client = self._web_client_factory(token=bot_token)
if self._bot_user_id is not None:
return
try:
self._web_client.reactions_add(
auth_info = await asyncio.to_thread(self._web_client.auth_test)
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
if user_id is None:
auth_get = getattr(auth_info, "get", None)
user_id = auth_get("user_id") if callable(auth_get) else None
if isinstance(user_id, str) and user_id:
self._bot_user_id = user_id
except Exception:
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
async def _get_web_client_for_message(self, msg: OutboundMessage):
if msg.connection_id and self._connection_repo is not None:
credentials = await self._connection_repo.get_credentials(msg.connection_id)
access_token = credentials.get("access_token") if credentials else None
if not access_token:
return self._web_client
# WebClient keeps its own HTTP session and rate-limit state, so
# reuse one per connection until its token changes.
cached = self._connection_web_clients.get(msg.connection_id)
if cached is not None and cached[0] == access_token:
return cached[1]
if self._web_client_factory is None:
from slack_sdk import WebClient
self._web_client_factory = WebClient
web_client = self._web_client_factory(token=access_token)
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
return web_client
return self._web_client
@staticmethod
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
try:
web_client.reactions_add(
channel=channel_id,
timestamp=timestamp,
name=emoji,
@@ -213,6 +256,12 @@ class SlackChannel(Channel):
if "already_reacted" not in str(exc):
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
if not self._web_client:
return
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
if not self._web_client:
@@ -249,12 +298,15 @@ class SlackChannel(Channel):
# Handle message events (DM or @mention)
if etype in ("message", "app_mention"):
self._handle_message_event(event)
self._handle_message_event(
event,
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
)
except Exception:
logger.exception("Error processing Slack event")
def _handle_message_event(self, event: dict) -> None:
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
# Ignore bot messages
if event.get("bot_id") or event.get("subtype"):
return
@@ -272,6 +324,19 @@ class SlackChannel(Channel):
if not text:
return
connect_code = extract_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
event=event,
team_id=str(team_id or event.get("team") or ""),
code=connect_code,
),
self._loop,
)
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
@@ -297,4 +362,61 @@ class SlackChannel(Channel):
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
# Send "running" reply first (fire-and-forget from SDK thread)
self._send_running_reply(channel_id, thread_ts)
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
if self._connection_repo is None:
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
else:
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="slack",
workspace_id=workspace_id,
)
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
channel_id = str(event.get("channel") or "")
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
if state is None:
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
return True
user_id = str(event.get("user") or "")
if not user_id or not team_id:
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="slack",
external_account_id=user_id,
workspace_id=team_id,
metadata={
"team_id": team_id,
"channel_id": channel_id,
},
status="connected",
)
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
return True
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
if not self._web_client or not channel_id:
return
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
if thread_ts:
kwargs["thread_ts"] = thread_ts
try:
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
except Exception:
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
+57
View File
@@ -8,6 +8,7 @@ import threading
from typing import Any
from app.channels.base import Channel
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ class TelegramChannel(Channel):
pass
# chat_id -> last sent message_id for threaded replies
self._last_bot_message: dict[str, int] = {}
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -233,6 +235,54 @@ class TelegramChannel(Channel):
return True
return user_id in self._allowed_users
@staticmethod
def _telegram_display_name(user) -> str:
full_name = getattr(user, "full_name", None)
if isinstance(full_name, str) and full_name:
return full_name
username = getattr(user, "username", None)
if isinstance(username, str) and username:
return username
return str(getattr(user, "id", ""))
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
if self._connection_repo is None or not state_token:
return False
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
if state is None:
await update.message.reply_text("Telegram connection link is invalid or expired.")
return True
owner_user_id = state["owner_user_id"]
user_id = str(update.effective_user.id)
chat_id = str(update.effective_chat.id)
connection = await self._connection_repo.upsert_connection(
owner_user_id=owner_user_id,
provider="telegram",
external_account_id=user_id,
external_account_name=self._telegram_display_name(update.effective_user),
workspace_id=chat_id,
workspace_name=None,
metadata={
"chat_id": chat_id,
"chat_type": update.effective_chat.type,
"telegram_username": getattr(update.effective_user, "username", None),
},
status="connected",
)
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
await update.message.reply_text("Telegram connected to DeerFlow.")
return True
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="telegram",
workspace_id=inbound.chat_id,
)
def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None)
username = getattr(bot, "username", None)
@@ -264,6 +314,11 @@ class TelegramChannel(Channel):
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
return
args = getattr(context, "args", []) if context is not None else []
if args:
handled = await self._bind_connection_from_start_token(update, str(args[0]))
if handled:
return
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
@@ -299,6 +354,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id,
)
inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
@@ -341,6 +397,7 @@ class TelegramChannel(Channel):
thread_ts=msg_id,
)
inbound.topic_id = topic_id
inbound = await self._attach_connection_identity(inbound)
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
+60 -2
View File
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
self._connection_repo = config.get("connection_repo")
self._load_state()
async def start(self) -> None:
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
if thread_ts:
self._context_tokens_by_thread[thread_ts] = context_token
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
chat_id=chat_id,
context_token=context_token,
code=connect_code,
)
if handled:
return
inbound = self._make_inbound(
chat_id=chat_id,
user_id=chat_id,
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
},
)
inbound.topic_id = None
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wechat",
workspace_id=inbound.chat_id,
)
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
if state is None:
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
return True
if not chat_id:
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wechat",
external_account_id=chat_id,
workspace_id=chat_id,
metadata={
"context_token": context_token,
},
status="connected",
)
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
return True
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
if not context_token:
return
await self._send_text_message(
chat_id=chat_id,
context_token=context_token,
text=text,
client_id_prefix="deerflow-connect",
max_retries=1,
)
async def _ensure_authenticated(self) -> bool:
async with self._auth_lock:
if self._bot_token:
+58 -1
View File
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
InboundMessage,
InboundMessageType,
MessageBus,
OutboundMessage,
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid")
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
handled = await self._bind_connection_from_connect_code(
frame=frame,
user_id=str(user_id or ""),
code=connect_code,
)
if handled:
return
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
except Exception:
pass
inbound = await self._attach_connection_identity(inbound)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="wecom",
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
if state is None:
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
return True
if not user_id:
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
return True
body = frame.get("body", {}) or {}
workspace_id = str(body.get("aibotid") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="wecom",
external_account_id=user_id,
workspace_id=workspace_id,
metadata={
"aibotid": workspace_id,
"chattype": body.get("chattype"),
},
status="connected",
)
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
return True
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
if not self._ws_client:
return
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._ws_client:
return