mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-13 10:55:59 +00:00
feat(im): Add user-owned IM channel connections (#3487)
* Add user-owned IM channel connections * Fix dev startup and channel connect popup * Use async channel connect flow * Harden dev service daemon startup * Support local IM channel connections * Align IM connections with local channels * Fix safe user id digest algorithm * Address Copilot IM channel feedback * Address IM channel review comments * Support all integrated IM channel connections * Format additional channel connection tests * Keep unavailable channel connect buttons clickable * Fix IM channel provider icons * Add runtime setup for enabled IM channels * Guard global shortcut key handling * Keep configured IM channels editable * Avoid password autofill for channel secrets * Make channel threads visible to connection owners * Persist IM runtime config locally * Allow disconnecting runtime IM channels * Route no-auth channel sessions to local user * Use default user for auth-disabled local mode * Show IM channel source on threads * Prefill IM channel runtime config * Reflect IM channel runtime health * Ignore Feishu message read events * Ignore Feishu non-content message events * Let setup wizard enable IM channels * Fix frontend formatting after merge * Stabilize backend tests without local config * Isolate channel runtime config tests * Address channel connection review comments * Use sha256 user buckets with legacy migration * Ensure runtime IM channels are ready after restart * Persist disconnected IM channel state * Address channel connection review comments * Address channel connection review findings Frontend connect flow: - Open the runtime-config dialog only when a provider still needs credentials; configured providers go straight to the connect flow, so the binding-code/deep-link path is reachable from the UI again. - After saving credentials, continue into the connect flow when a user binding is still required (multi-user mode) instead of stopping at a "Connected" toast. - Extract shared provider-state helpers to core/channels/provider-state and add unit + e2e coverage for the direct-connect and configure-then-connect paths. Provider status semantics: - Report connection_status from the user's newest connection row; with no binding it is not_connected, except in auth-disabled local mode where a configured running channel is effectively connected. Concurrency and event-loop correctness: - Offload ChannelRuntimeConfigStore construction and writes, channel service construction, and Slack connection replies to threads; add a tests/blocking_io/ anchor for the runtime-config handlers. - Consume binding codes with a conditional UPDATE so a code can only be used once under concurrent workers; retry upsert_connection as an update when a concurrent insert wins the unique constraint. - Serialize ensure_channel_ready per channel so concurrent provider polls cannot double-start a channel worker. Config and migration hardening: - Stop mutating the get_app_config()-cached Telegram provider config; the runtime store now owns the UI-entered bot username. - Register channel_connections in STARTUP_ONLY_FIELDS with the standardized startup-only Field description. - Match the legacy unsafe-id bucket by recomputing its exact SHA-1 name so another user's same-prefix bucket can never be migrated. - Remove the unused Telegram process_webhook_update path and document src/core/channels in the frontend docs. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Address PR review comments on authz scoping and channel runtime Security (review feedback from ShenAC-SAC): - Scope internal-token callers to the connection owner carried in X-DeerFlow-Owner-User-Id instead of bypassing owner checks outright, in both require_permission(owner_check=True) and the stateless run endpoints. Internal callers keep access to their own and shared/legacy threads, and may claim a default-owned channel thread for its real owner, but a leaked internal token no longer grants cross-user thread access. - Require admin privileges for POST/DELETE /api/channels/{provider}/ runtime-config: runtime credentials and channel workers are instance-wide shared state (same model as the MCP config API). Read-only provider listing stays available to all users. Performance (review feedback from willem-bd): - Skip the redundant thread channel-metadata PATCH after the first successful backfill per thread. - Reuse the per-connection Slack WebClient until its token changes instead of constructing one per outbound message. - Reconcile channel readiness for all providers concurrently in GET /api/channels/providers. Also resolve the code-quality unused-import flag in the blocking-io anchor by pre-importing the channel service via importlib. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Fix prettier formatting in provider-state test Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Reconcile UI runtime channel config with config reload on restart Main now reloads a channel's config.yaml entry on restart_channel() (#3514, issue #3497). Adapt the user-owned connection flow to coexist: - configure_channel() restarts with reload_config=False — the caller just supplied the authoritative config (browser-entered credentials that are never written to config.yaml), so a file reload must not clobber it with the stale on-disk entry. - _load_channel_config() re-applies the UI runtime-store overlay used at startup, so an operator-triggered restart keeps browser-entered credentials for channels without a config.yaml entry and does not resurrect a channel disconnected from the UI. - Offload the reload's disk IO (config.yaml + runtime store) with asyncio.to_thread, matching the blocking-IO policy on this branch. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> --------- Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def extract_connect_code(text: str) -> str | None:
|
||||
"""Extract the one-time channel binding code from a connect command."""
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def is_known_channel_command(text: str) -> bool:
|
||||
"""Return whether text starts with a registered channel control command."""
|
||||
if not text.startswith("/"):
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.channels.message_bus import InboundMessage
|
||||
|
||||
|
||||
async def attach_connection_identity(
|
||||
inbound: InboundMessage,
|
||||
*,
|
||||
repo: Any,
|
||||
provider: str,
|
||||
workspace_id: str | None,
|
||||
fallback_without_workspace: bool = False,
|
||||
) -> InboundMessage:
|
||||
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
||||
if repo is None:
|
||||
return inbound
|
||||
|
||||
workspace_candidates: list[str | None] = []
|
||||
if workspace_id:
|
||||
workspace_candidates.append(workspace_id)
|
||||
if fallback_without_workspace:
|
||||
workspace_candidates.append(None)
|
||||
if not workspace_candidates:
|
||||
return inbound
|
||||
|
||||
for candidate in workspace_candidates:
|
||||
connection = await repo.find_connection_by_external_identity(
|
||||
provider=provider,
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=candidate,
|
||||
)
|
||||
if connection is None:
|
||||
continue
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
return inbound
|
||||
@@ -14,7 +14,8 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
|
||||
self._incoming_messages: dict[str, Any] = {}
|
||||
self._incoming_messages_lock = threading.Lock()
|
||||
self._card_repliers: dict[str, Any] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
|
||||
text[:100],
|
||||
)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
conversation_type=conversation_type,
|
||||
sender_staff_id=sender_staff_id,
|
||||
sender_nick=sender_nick,
|
||||
conversation_id=conversation_id,
|
||||
code=connect_code,
|
||||
),
|
||||
self._main_loop,
|
||||
)
|
||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||
else:
|
||||
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
if _is_dingtalk_command(text):
|
||||
msg_type = InboundMessageType.COMMAND
|
||||
else:
|
||||
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
|
||||
return ""
|
||||
|
||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
# Running reply must finish before publish_inbound so AI card tracks are
|
||||
# registered before the manager emits streaming outbounds.
|
||||
await self._send_running_reply(chat_id, inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
@staticmethod
|
||||
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
|
||||
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
|
||||
return conversation_id
|
||||
return None
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
|
||||
conversation_id = str(inbound.metadata.get("conversation_id") or "")
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="dingtalk",
|
||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(
|
||||
self,
|
||||
*,
|
||||
conversation_type: str,
|
||||
sender_staff_id: str,
|
||||
sender_nick: str,
|
||||
conversation_id: str,
|
||||
code: str,
|
||||
) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connection code is invalid or expired.",
|
||||
)
|
||||
return True
|
||||
|
||||
if not sender_staff_id:
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connection could not be completed from this message.",
|
||||
)
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="dingtalk",
|
||||
external_account_id=sender_staff_id,
|
||||
external_account_name=sender_nick or None,
|
||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||
metadata={
|
||||
"conversation_type": conversation_type,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connected to DeerFlow.",
|
||||
)
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(
|
||||
self,
|
||||
conversation_type: str,
|
||||
sender_staff_id: str,
|
||||
conversation_id: str,
|
||||
text: str,
|
||||
) -> None:
|
||||
robot_code = self._client_id
|
||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
||||
if conversation_id:
|
||||
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
||||
return
|
||||
if sender_staff_id:
|
||||
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
||||
|
||||
@@ -10,8 +10,9 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
|
||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._discord_module = None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -287,6 +289,10 @@ class DiscordChannel(Channel):
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||
return
|
||||
|
||||
# --- Determine thread/channel routing and typing target ---
|
||||
thread_id = None
|
||||
chat_id = None
|
||||
@@ -315,6 +321,7 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
self._publish(inbound)
|
||||
# Start typing indicator in the thread
|
||||
if typing_target:
|
||||
@@ -422,6 +429,7 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
|
||||
# Start typing indicator in the correct target (thread or channel)
|
||||
if typing_target:
|
||||
@@ -436,6 +444,60 @@ class DiscordChannel(Channel):
|
||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="discord",
|
||||
workspace_id=guild_id,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
guild = getattr(message, "guild", None)
|
||||
channel = getattr(message, "channel", None)
|
||||
author = getattr(message, "author", None)
|
||||
user_id = str(getattr(author, "id", "") or "")
|
||||
if not user_id:
|
||||
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
guild_id = str(getattr(guild, "id", "") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="discord",
|
||||
external_account_id=user_id,
|
||||
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
|
||||
workspace_id=guild_id,
|
||||
workspace_name=getattr(guild, "name", None) if guild is not None else None,
|
||||
metadata={
|
||||
"guild_id": guild_id,
|
||||
"channel_id": str(getattr(channel, "id", "") or ""),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _send_connection_reply(message, text: str) -> None:
|
||||
channel = getattr(message, "channel", None)
|
||||
send = getattr(channel, "send", None)
|
||||
if send is None:
|
||||
return
|
||||
try:
|
||||
await send(text)
|
||||
except Exception:
|
||||
logger.exception("[Discord] failed to send connection reply")
|
||||
|
||||
def _run_client(self) -> None:
|
||||
self._discord_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._discord_loop)
|
||||
|
||||
@@ -11,7 +11,8 @@ import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
PENDING_CLARIFICATION_METADATA_KEY,
|
||||
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
||||
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
|
||||
self._CreateImageRequestBody = None
|
||||
self._GetMessageResourceRequest = None
|
||||
self._thread_lock = threading.Lock()
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@staticmethod
|
||||
def _non_empty_str(value: Any) -> str | None:
|
||||
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
|
||||
def supports_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
if not self._running:
|
||||
return False
|
||||
return self._thread is not None and self._thread.is_alive()
|
||||
|
||||
def _build_event_handler(self, lark):
|
||||
return (
|
||||
lark.EventDispatcherHandler.builder("", "")
|
||||
.register_p2_im_message_receive_v1(self._on_message)
|
||||
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
|
||||
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
|
||||
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
|
||||
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
|
||||
.build()
|
||||
)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
return
|
||||
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
|
||||
# thread's uvloop.
|
||||
_ws_client_mod.loop = loop
|
||||
|
||||
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
|
||||
event_handler = self._build_event_handler(lark)
|
||||
ws_client = lark.ws.Client(
|
||||
app_id=app_id,
|
||||
app_secret=app_secret,
|
||||
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
|
||||
except Exception:
|
||||
if self._running:
|
||||
logger.exception("Feishu WebSocket error")
|
||||
self._running = False
|
||||
|
||||
def _on_ignored_message_event(self, event) -> None:
|
||||
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
|
||||
|
||||
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
||||
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
||||
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
||||
self._ensure_running_card_started(msg_id)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="feishu",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
|
||||
if state is None:
|
||||
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id or not chat_id:
|
||||
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="feishu",
|
||||
external_account_id=user_id,
|
||||
workspace_id=chat_id,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
def _on_message(self, event) -> None:
|
||||
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
||||
try:
|
||||
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
|
||||
logger.info("[Feishu] empty text, ignoring message")
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
message_id=msg_id,
|
||||
chat_id=chat_id,
|
||||
user_id=sender_id,
|
||||
code=connect_code,
|
||||
),
|
||||
self._main_loop,
|
||||
)
|
||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||
else:
|
||||
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
# Only treat known slash commands as commands; absolute paths and
|
||||
# other slash-prefixed text should be handled as normal chat.
|
||||
if _is_feishu_command(text):
|
||||
|
||||
+152
-30
@@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
|
||||
return metadata
|
||||
|
||||
|
||||
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
|
||||
channel_source: dict[str, Any] = {
|
||||
"type": "im_channel",
|
||||
"provider": msg.channel_name,
|
||||
"chat_id": msg.chat_id,
|
||||
}
|
||||
if msg.topic_id:
|
||||
channel_source["topic_id"] = msg.topic_id
|
||||
if msg.thread_ts:
|
||||
channel_source["thread_ts"] = msg.thread_ts
|
||||
if msg.connection_id:
|
||||
channel_source["connection_id"] = msg.connection_id
|
||||
|
||||
return {"channel_source": channel_source}
|
||||
|
||||
|
||||
def _extract_text_content(content: Any) -> str:
|
||||
"""Extract text from a streaming payload content field."""
|
||||
if isinstance(content, str):
|
||||
@@ -440,6 +456,43 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
|
||||
return message
|
||||
|
||||
|
||||
def _auth_disabled_owner_user_id() -> str | None:
|
||||
try:
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||
except Exception:
|
||||
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
||||
return None
|
||||
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
||||
|
||||
|
||||
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
||||
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
||||
|
||||
|
||||
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if owner_user_id:
|
||||
msg.owner_user_id = owner_user_id
|
||||
return msg
|
||||
|
||||
|
||||
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if not owner_user_id:
|
||||
return None
|
||||
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
||||
|
||||
|
||||
def _safe_user_id_for_run(raw_user_id: str) -> str:
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
try:
|
||||
return get_paths().prepare_user_dir_for_raw_id(raw_user_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to prepare channel run user directory")
|
||||
return make_safe_user_id(raw_user_id)
|
||||
|
||||
|
||||
def _resolve_slash_skill_command(
|
||||
text: str,
|
||||
available_skills: set[str] | None = None,
|
||||
@@ -670,6 +723,7 @@ class ChannelManager:
|
||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||
default_session: dict[str, Any] | None = None,
|
||||
channel_sessions: dict[str, Any] | None = None,
|
||||
connection_repo: Any | None = None,
|
||||
) -> None:
|
||||
self.bus = bus
|
||||
self.store = store
|
||||
@@ -679,7 +733,9 @@ class ChannelManager:
|
||||
self._assistant_id = assistant_id
|
||||
self._default_session = _as_dict(default_session)
|
||||
self._channel_sessions = dict(channel_sessions or {})
|
||||
self._connection_repo = connection_repo
|
||||
self._client = None # lazy init — langgraph_sdk async client
|
||||
self._channel_metadata_synced: set[str] = set()
|
||||
self._skill_storage: SkillStorage | None = None
|
||||
self._csrf_token = generate_csrf_token()
|
||||
self._semaphore: asyncio.Semaphore | None = None
|
||||
@@ -728,12 +784,17 @@ class ChannelManager:
|
||||
configurable["checkpoint_ns"] = ""
|
||||
configurable["thread_id"] = thread_id
|
||||
|
||||
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
||||
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
||||
# under ``channel_user_id`` for platform-facing lookups.
|
||||
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
||||
# For browser-connected IM channels, prefer the DeerFlow account that
|
||||
# owns the connection. Preserve the raw platform user under
|
||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||
owner_user_id = _effective_owner_user_id(msg)
|
||||
if owner_user_id:
|
||||
run_context_identity["user_id"] = _safe_user_id_for_run(owner_user_id)
|
||||
elif msg.user_id:
|
||||
run_context_identity["user_id"] = _safe_user_id_for_run(msg.user_id)
|
||||
if msg.user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||
run_context_identity["channel_user_id"] = msg.user_id
|
||||
|
||||
run_context = _merge_dicts(
|
||||
@@ -845,6 +906,7 @@ class ChannelManager:
|
||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||
|
||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||
msg = _apply_effective_owner(msg)
|
||||
async with self._semaphore:
|
||||
try:
|
||||
if msg.msg_type == InboundMessageType.COMMAND:
|
||||
@@ -877,10 +939,27 @@ class ChannelManager:
|
||||
|
||||
# -- chat handling -----------------------------------------------------
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
thread = await client.threads.create()
|
||||
thread_id = thread["thread_id"]
|
||||
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
return await self._connection_repo.get_thread_id(
|
||||
msg.connection_id,
|
||||
msg.chat_id,
|
||||
msg.topic_id,
|
||||
)
|
||||
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
|
||||
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
||||
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
||||
await self._connection_repo.set_thread_id(
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
provider=msg.channel_name,
|
||||
external_conversation_id=msg.chat_id,
|
||||
external_topic_id=msg.topic_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return
|
||||
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
@@ -888,18 +967,49 @@ class ChannelManager:
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
metadata = _thread_channel_metadata(msg)
|
||||
owner_headers = _owner_headers(msg)
|
||||
if owner_headers:
|
||||
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
|
||||
else:
|
||||
thread = await client.threads.create(metadata=metadata)
|
||||
thread_id = thread["thread_id"]
|
||||
await self._store_thread_id(msg, thread_id)
|
||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||
return thread_id
|
||||
|
||||
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
|
||||
"""Best-effort source metadata backfill for existing IM-created threads."""
|
||||
# The metadata (provider/chat/topic) is constant for a thread, so one
|
||||
# successful backfill per manager lifetime is enough — skip the
|
||||
# redundant PATCH on every subsequent inbound message.
|
||||
if thread_id in self._channel_metadata_synced:
|
||||
return
|
||||
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
|
||||
if owner_headers := _owner_headers(msg):
|
||||
update_kwargs["headers"] = owner_headers
|
||||
try:
|
||||
await client.threads.update(thread_id, **update_kwargs)
|
||||
except Exception:
|
||||
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
|
||||
return
|
||||
if len(self._channel_metadata_synced) > 4096:
|
||||
self._channel_metadata_synced.clear()
|
||||
self._channel_metadata_synced.add(thread_id)
|
||||
|
||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||
client = self._get_client()
|
||||
|
||||
# Look up existing DeerFlow thread.
|
||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
if thread_id:
|
||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||
await self._update_thread_channel_metadata(client, msg, thread_id)
|
||||
|
||||
# No existing thread found — create a new one
|
||||
if thread_id is None:
|
||||
@@ -940,14 +1050,19 @@ class ChannelManager:
|
||||
return
|
||||
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
run_kwargs: dict[str, Any] = {
|
||||
"input": {"messages": [human_message]},
|
||||
"config": run_config,
|
||||
"context": run_context,
|
||||
"multitask_strategy": "reject",
|
||||
}
|
||||
if owner_headers := _owner_headers(msg):
|
||||
run_kwargs["headers"] = owner_headers
|
||||
try:
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [human_message]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
multitask_strategy="reject",
|
||||
**run_kwargs,
|
||||
)
|
||||
except Exception as exc:
|
||||
if _is_thread_busy_error(exc):
|
||||
@@ -984,6 +1099,8 @@ class ChannelManager:
|
||||
artifacts=artifacts,
|
||||
attachments=attachments,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||
@@ -1008,16 +1125,21 @@ class ChannelManager:
|
||||
last_published_text = ""
|
||||
last_publish_at = 0.0
|
||||
stream_error: BaseException | None = None
|
||||
stream_kwargs: dict[str, Any] = {
|
||||
"input": {"messages": [human_message]},
|
||||
"config": run_config,
|
||||
"context": run_context,
|
||||
"stream_mode": ["messages-tuple", "values"],
|
||||
"multitask_strategy": "reject",
|
||||
}
|
||||
if owner_headers := _owner_headers(msg):
|
||||
stream_kwargs["headers"] = owner_headers
|
||||
|
||||
try:
|
||||
async for chunk in client.runs.stream(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [human_message]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
stream_mode=["messages-tuple", "values"],
|
||||
multitask_strategy="reject",
|
||||
**stream_kwargs,
|
||||
):
|
||||
event = getattr(chunk, "event", "")
|
||||
data = getattr(chunk, "data", None)
|
||||
@@ -1047,6 +1169,8 @@ class ChannelManager:
|
||||
text=latest_text,
|
||||
is_final=False,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata),
|
||||
)
|
||||
)
|
||||
@@ -1093,6 +1217,8 @@ class ChannelManager:
|
||||
attachments=attachments,
|
||||
is_final=True,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
)
|
||||
@@ -1124,18 +1250,10 @@ class ChannelManager:
|
||||
if reply is None and command == "new":
|
||||
# Create a new thread through Gateway
|
||||
client = self._get_client()
|
||||
thread = await client.threads.create()
|
||||
new_thread_id = thread["thread_id"]
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
new_thread_id,
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
await self._create_thread(client, msg)
|
||||
reply = "New conversation started."
|
||||
elif reply is None and command == "status":
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||
elif reply is None and command == "models":
|
||||
reply = await self._fetch_gateway("/api/models", "models")
|
||||
@@ -1174,9 +1292,11 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
text=reply,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
@@ -1212,9 +1332,11 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
text=error_text,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
|
||||
@@ -44,6 +44,12 @@ class InboundMessage:
|
||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||
reuse the same DeerFlow thread. When ``None``, each message
|
||||
creates a new thread (one-shot Q&A).
|
||||
connection_id: Optional DeerFlow channel connection id. When present,
|
||||
conversation mapping is scoped by the connection instead of the
|
||||
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
Platform user ids stay in ``user_id``.
|
||||
workspace_id: Optional external workspace/guild/team id.
|
||||
files: Optional list of file attachments (platform-specific dicts).
|
||||
metadata: Arbitrary extra data from the channel.
|
||||
created_at: Unix timestamp when the message was created.
|
||||
@@ -56,6 +62,9 @@ class InboundMessage:
|
||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||
thread_ts: str | None = None
|
||||
topic_id: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
workspace_id: str | None = None
|
||||
files: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
@@ -95,6 +104,9 @@ class OutboundMessage:
|
||||
is_final: Whether this is the final message in the response stream.
|
||||
thread_ts: Optional platform thread identifier for threaded replies.
|
||||
metadata: Arbitrary extra data.
|
||||
connection_id: Optional DeerFlow channel connection id used for
|
||||
connection-specific outbound credentials.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
created_at: Unix timestamp.
|
||||
"""
|
||||
|
||||
@@ -106,6 +118,8 @@ class OutboundMessage:
|
||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||
is_final: bool = True
|
||||
thread_ts: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Local persistence for runtime IM channel configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
|
||||
|
||||
|
||||
class ChannelRuntimeConfigStore:
|
||||
"""JSON-backed store for channel credentials entered from the UI.
|
||||
|
||||
This intentionally mirrors ``ChannelStore``: local/private deployments get
|
||||
durable runtime configuration without needing a public callback URL or a
|
||||
config.yaml edit.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str | Path | None = None) -> None:
|
||||
if path is None:
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
|
||||
self._path = Path(path)
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._data: dict[str, dict[str, Any]] = self._load()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _load(self) -> dict[str, dict[str, Any]]:
|
||||
if self._path.exists():
|
||||
try:
|
||||
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
|
||||
return {}
|
||||
if isinstance(raw, dict):
|
||||
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
|
||||
return {}
|
||||
|
||||
def _save(self) -> None:
|
||||
fd = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
dir=self._path.parent,
|
||||
suffix=".tmp",
|
||||
delete=False,
|
||||
)
|
||||
try:
|
||||
json.dump(self._data, fd, indent=2, ensure_ascii=False)
|
||||
fd.close()
|
||||
Path(fd.name).replace(self._path)
|
||||
try:
|
||||
self._path.chmod(0o600)
|
||||
except OSError:
|
||||
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
|
||||
except BaseException:
|
||||
fd.close()
|
||||
Path(fd.name).unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
def load_all(self) -> dict[str, dict[str, Any]]:
|
||||
with self._lock:
|
||||
return {name: dict(config) for name, config in self._data.items()}
|
||||
|
||||
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
|
||||
with self._lock:
|
||||
config = self._data.get(provider)
|
||||
return dict(config) if isinstance(config, dict) else None
|
||||
|
||||
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
|
||||
with self._lock:
|
||||
self._data[provider] = dict(config)
|
||||
self._save()
|
||||
|
||||
def set_provider_disconnected(self, provider: str) -> None:
|
||||
with self._lock:
|
||||
self._data[provider] = {
|
||||
"enabled": False,
|
||||
RUNTIME_CHANNEL_DISABLED_FLAG: True,
|
||||
}
|
||||
self._save()
|
||||
|
||||
def remove_provider_config(self, provider: str) -> bool:
|
||||
with self._lock:
|
||||
if provider not in self._data:
|
||||
return False
|
||||
del self._data[provider]
|
||||
self._save()
|
||||
return True
|
||||
|
||||
|
||||
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||
provider_config = getattr(channel_connections_config, provider, None)
|
||||
return bool(getattr(provider_config, "enabled", False))
|
||||
|
||||
|
||||
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
|
||||
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
|
||||
|
||||
|
||||
def merge_runtime_channel_configs(
|
||||
channels_config: dict[str, Any],
|
||||
channel_connections_config: Any,
|
||||
*,
|
||||
store: ChannelRuntimeConfigStore | None = None,
|
||||
) -> None:
|
||||
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
|
||||
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||
return
|
||||
|
||||
runtime_store = store or ChannelRuntimeConfigStore()
|
||||
for provider, runtime_config in runtime_store.load_all().items():
|
||||
if not _provider_enabled(channel_connections_config, provider):
|
||||
continue
|
||||
if _runtime_channel_disconnected(runtime_config):
|
||||
channels_config.pop(provider, None)
|
||||
continue
|
||||
existing = channels_config.get(provider)
|
||||
merged = dict(runtime_config)
|
||||
if isinstance(existing, dict):
|
||||
merged.update(existing)
|
||||
channels_config[provider] = merged
|
||||
|
||||
|
||||
def apply_runtime_connection_config(
|
||||
channel_connections_config: Any,
|
||||
*,
|
||||
store: ChannelRuntimeConfigStore | None = None,
|
||||
) -> Any:
|
||||
"""Apply persisted connection metadata that lives outside ``channels``.
|
||||
|
||||
Telegram uses a bot username for deep links; UI-entered values are stored
|
||||
with the runtime channel config so local restarts keep the provider
|
||||
configured.
|
||||
"""
|
||||
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||
return channel_connections_config
|
||||
|
||||
runtime_store = store or ChannelRuntimeConfigStore()
|
||||
telegram_runtime_config = runtime_store.get_provider_config("telegram")
|
||||
bot_username = ""
|
||||
if isinstance(telegram_runtime_config, dict):
|
||||
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
|
||||
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
|
||||
return channel_connections_config
|
||||
|
||||
config = channel_connections_config.model_copy(deep=True)
|
||||
config.telegram.bot_username = bot_username
|
||||
return config
|
||||
+145
-26
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -9,6 +10,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from app.channels.base import Channel
|
||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||
from app.channels.message_bus import MessageBus
|
||||
from app.channels.runtime_config_store import merge_runtime_channel_configs
|
||||
from app.channels.store import ChannelStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -42,6 +44,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||
|
||||
|
||||
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
|
||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||
return any(not isinstance(channel_config.get(key), bool) and channel_config.get(key) is not None and str(channel_config[key]).strip() for key in cred_keys)
|
||||
|
||||
|
||||
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
||||
value = config.pop(config_key, None)
|
||||
if isinstance(value, str) and value.strip():
|
||||
@@ -52,6 +59,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
||||
return default
|
||||
|
||||
|
||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
merge_runtime_channel_configs(channels_config, connection_config)
|
||||
|
||||
|
||||
def _make_connection_repo(app_config: AppConfig):
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return None
|
||||
|
||||
try:
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel connection repository")
|
||||
return None
|
||||
|
||||
session_factory = get_session_factory()
|
||||
if session_factory is None:
|
||||
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||
return None
|
||||
return ChannelConnectionRepository(session_factory)
|
||||
|
||||
|
||||
class ChannelService:
|
||||
"""Manages the lifecycle of all configured IM channels.
|
||||
|
||||
@@ -59,9 +90,10 @@ class ChannelService:
|
||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||
"""
|
||||
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||
self.bus = MessageBus()
|
||||
self.store = ChannelStore()
|
||||
self._connection_repo = connection_repo
|
||||
config = dict(channels_config or {})
|
||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||
@@ -74,10 +106,12 @@ class ChannelService:
|
||||
gateway_url=gateway_url,
|
||||
default_session=default_session if isinstance(default_session, dict) else None,
|
||||
channel_sessions=channel_sessions,
|
||||
connection_repo=connection_repo,
|
||||
)
|
||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||
self._config = config
|
||||
self._running = False
|
||||
self._readiness_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
@classmethod
|
||||
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
|
||||
@@ -90,8 +124,9 @@ class ChannelService:
|
||||
# extra fields are allowed by AppConfig (extra="allow")
|
||||
extra = app_config.model_extra or {}
|
||||
if "channels" in extra:
|
||||
channels_config = extra["channels"]
|
||||
return cls(channels_config=channels_config)
|
||||
channels_config = dict(extra["channels"] or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the manager and all enabled channels."""
|
||||
@@ -99,36 +134,83 @@ class ChannelService:
|
||||
return
|
||||
|
||||
await self.manager.start()
|
||||
self._running = True
|
||||
|
||||
ready_status = await self.ensure_ready_channels(attempts=2)
|
||||
ready_count = sum(1 for ready in ready_status.values() if ready)
|
||||
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
|
||||
|
||||
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
|
||||
"""Start or restart enabled configured channels that are not ready."""
|
||||
ready_status: dict[str, bool] = {}
|
||||
for name, channel_config in self._config.items():
|
||||
if not isinstance(channel_config, dict):
|
||||
continue
|
||||
if not channel_config.get("enabled", False):
|
||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
||||
if has_creds:
|
||||
if _channel_has_credentials(name, channel_config):
|
||||
logger.warning(
|
||||
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
|
||||
name,
|
||||
name,
|
||||
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
||||
)
|
||||
else:
|
||||
logger.info("Channel %s is disabled, skipping", name)
|
||||
logger.info("A configured channel is disabled, skipping")
|
||||
continue
|
||||
|
||||
await self._start_channel(name, channel_config)
|
||||
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
|
||||
return ready_status
|
||||
|
||||
self._running = True
|
||||
logger.info("ChannelService started with channels: %s", list(self._channels.keys()))
|
||||
async def ensure_channel_ready(
|
||||
self,
|
||||
name: str,
|
||||
config: dict[str, Any] | None = None,
|
||||
*,
|
||||
attempts: int = 1,
|
||||
) -> bool:
|
||||
"""Ensure a single enabled channel is running using its current config."""
|
||||
if not self._running:
|
||||
logger.warning("ChannelService is not running; cannot ensure channel readiness")
|
||||
return False
|
||||
|
||||
if config is not None:
|
||||
self._config[name] = dict(config)
|
||||
|
||||
# Serialize per channel: readiness is polled from request handlers, so
|
||||
# concurrent calls must not stop/start the same channel worker twice.
|
||||
lock = self._readiness_locks.setdefault(name, asyncio.Lock())
|
||||
async with lock:
|
||||
channel_config = self._config.get(name)
|
||||
if not channel_config or not isinstance(channel_config, dict):
|
||||
logger.warning("No config for requested channel")
|
||||
return False
|
||||
if not channel_config.get("enabled", False):
|
||||
return False
|
||||
|
||||
channel = self._channels.get(name)
|
||||
if channel is not None and channel.is_running:
|
||||
return True
|
||||
|
||||
if channel is not None:
|
||||
try:
|
||||
await channel.stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping non-running channel before readiness retry")
|
||||
self._channels.pop(name, None)
|
||||
|
||||
max_attempts = max(1, attempts)
|
||||
for attempt in range(max_attempts):
|
||||
if attempt > 0:
|
||||
logger.info("Retrying channel startup after readiness check")
|
||||
if await self._start_channel(name, channel_config):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all channels and the manager."""
|
||||
for name, channel in list(self._channels.items()):
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Channel %s stopped", name)
|
||||
logger.info("Channel stopped")
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel %s", name)
|
||||
logger.exception("Error stopping channel")
|
||||
self._channels.clear()
|
||||
|
||||
await self.manager.stop()
|
||||
@@ -140,6 +222,9 @@ class ChannelService:
|
||||
|
||||
Uses ``get_app_config()`` which detects file changes via mtime,
|
||||
so edits to ``config.yaml`` are picked up without a process restart.
|
||||
The UI runtime-config overlay applied at startup is re-applied here
|
||||
so a file-driven reload neither drops credentials entered from the
|
||||
browser nor resurrects a channel disconnected from it.
|
||||
Falls back to the cached ``self._config`` when config loading fails.
|
||||
"""
|
||||
try:
|
||||
@@ -147,7 +232,8 @@ class ChannelService:
|
||||
|
||||
app_config = get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = extra.get("channels", {})
|
||||
channels_config = dict(extra.get("channels") or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
channel_config = channels_config.get(name)
|
||||
if isinstance(channel_config, dict):
|
||||
# Update the cached config so get_status() stays consistent.
|
||||
@@ -157,18 +243,23 @@ class ChannelService:
|
||||
logger.exception("Failed to reload config for channel %s, using cached version", name)
|
||||
return self._config.get(name)
|
||||
|
||||
async def restart_channel(self, name: str) -> bool:
|
||||
async def restart_channel(self, name: str, *, reload_config: bool = True) -> bool:
|
||||
"""Restart a specific channel. Returns True if successful."""
|
||||
if name in self._channels:
|
||||
try:
|
||||
await self._channels[name].stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel %s for restart", name)
|
||||
logger.exception("Error stopping channel for restart")
|
||||
del self._channels[name]
|
||||
|
||||
config = self._load_channel_config(name)
|
||||
if reload_config:
|
||||
# Reading config.yaml and the runtime store is disk IO; keep it
|
||||
# off the event loop.
|
||||
config = await asyncio.to_thread(self._load_channel_config, name)
|
||||
else:
|
||||
config = self._config.get(name)
|
||||
if not config or not isinstance(config, dict):
|
||||
logger.warning("No config for channel %s", name)
|
||||
logger.warning("No config for requested channel")
|
||||
return False
|
||||
|
||||
if not config.get("enabled", False):
|
||||
@@ -177,11 +268,35 @@ class ChannelService:
|
||||
|
||||
return await self._start_channel(name, config)
|
||||
|
||||
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||
"""Apply runtime config for a channel and restart it if the service is running."""
|
||||
self._config[name] = dict(config)
|
||||
if not self._running:
|
||||
return True
|
||||
# The caller just supplied the authoritative config (e.g. credentials
|
||||
# entered in the browser that are never written to config.yaml) — a
|
||||
# file reload here would clobber it with the stale on-disk entry.
|
||||
return await self.restart_channel(name, reload_config=False)
|
||||
|
||||
async def remove_channel(self, name: str) -> bool:
|
||||
"""Remove runtime config for a channel and stop it if currently running."""
|
||||
self._config.pop(name, None)
|
||||
channel = self._channels.pop(name, None)
|
||||
if channel is None:
|
||||
return True
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Channel stopped and removed")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel for removal")
|
||||
return False
|
||||
|
||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||
"""Instantiate and start a single channel."""
|
||||
import_path = _CHANNEL_REGISTRY.get(name)
|
||||
if not import_path:
|
||||
logger.warning("Unknown channel type: %s", name)
|
||||
logger.warning("Unknown channel type")
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -189,24 +304,26 @@ class ChannelService:
|
||||
|
||||
channel_cls = resolve_class(import_path, base_class=None)
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel class for %s", name)
|
||||
logger.exception("Failed to import channel class")
|
||||
return False
|
||||
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
if self._connection_repo is not None:
|
||||
config["connection_repo"] = self._connection_repo
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
if not channel.is_running:
|
||||
self._channels.pop(name, None)
|
||||
logger.error("Channel %s did not enter a running state after start()", name)
|
||||
logger.error("Channel did not enter a running state after start()")
|
||||
return False
|
||||
logger.info("Channel %s started", name)
|
||||
logger.info("Channel started")
|
||||
return True
|
||||
except Exception:
|
||||
self._channels.pop(name, None)
|
||||
logger.exception("Failed to start channel %s", name)
|
||||
logger.exception("Failed to start channel")
|
||||
return False
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
@@ -245,7 +362,9 @@ async def start_channel_service(app_config: AppConfig | None = None) -> ChannelS
|
||||
global _channel_service
|
||||
if _channel_service is not None:
|
||||
return _channel_service
|
||||
_channel_service = ChannelService.from_app_config(app_config)
|
||||
# from_app_config reads the JSON channel store and runtime config files;
|
||||
# keep that disk IO off the event loop.
|
||||
_channel_service = await asyncio.to_thread(ChannelService.from_app_config, app_config)
|
||||
await _channel_service.start()
|
||||
return _channel_service
|
||||
|
||||
|
||||
+148
-26
@@ -9,7 +9,8 @@ from typing import Any
|
||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,6 +65,9 @@ class SlackChannel(Channel):
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._web_client_factory = config.get("web_client_factory")
|
||||
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
|
||||
configured_bot_user_id = config.get("bot_user_id")
|
||||
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
||||
|
||||
@@ -80,26 +84,28 @@ class SlackChannel(Channel):
|
||||
return
|
||||
|
||||
self._SocketModeResponse = SocketModeResponse
|
||||
if self._web_client_factory is None:
|
||||
self._web_client_factory = WebClient
|
||||
|
||||
bot_token = self.config.get("bot_token", "")
|
||||
app_token = self.config.get("app_token", "")
|
||||
|
||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||
if not bot_token:
|
||||
logger.error("Slack HTTP Events mode requires bot_token")
|
||||
return
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("Slack channel started in HTTP Events mode")
|
||||
return
|
||||
|
||||
if not bot_token or not app_token:
|
||||
logger.error("Slack channel requires bot_token and app_token")
|
||||
return
|
||||
|
||||
self._web_client = WebClient(token=bot_token)
|
||||
if self._bot_user_id is None:
|
||||
try:
|
||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||
if user_id is None:
|
||||
auth_get = getattr(auth_info, "get", None)
|
||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||
if isinstance(user_id, str) and user_id:
|
||||
self._bot_user_id = user_id
|
||||
except Exception:
|
||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=app_token,
|
||||
web_client=self._web_client,
|
||||
@@ -124,7 +130,8 @@ class SlackChannel(Channel):
|
||||
logger.info("Slack channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._web_client:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
return
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
@@ -137,11 +144,12 @@ class SlackChannel(Channel):
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||
# Add a completion reaction to the thread root
|
||||
if msg.thread_ts:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction,
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"white_check_mark",
|
||||
@@ -165,7 +173,8 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction,
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"x",
|
||||
@@ -177,7 +186,8 @@ class SlackChannel(Channel):
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._web_client:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -190,7 +200,7 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
kwargs["thread_ts"] = msg.thread_ts
|
||||
|
||||
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
||||
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||
return True
|
||||
except Exception:
|
||||
@@ -199,12 +209,45 @@ class SlackChannel(Channel):
|
||||
|
||||
# -- internal ----------------------------------------------------------
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
||||
self._web_client = self._web_client_factory(token=bot_token)
|
||||
if self._bot_user_id is not None:
|
||||
return
|
||||
try:
|
||||
self._web_client.reactions_add(
|
||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||
if user_id is None:
|
||||
auth_get = getattr(auth_info, "get", None)
|
||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||
if isinstance(user_id, str) and user_id:
|
||||
self._bot_user_id = user_id
|
||||
except Exception:
|
||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||
|
||||
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||
access_token = credentials.get("access_token") if credentials else None
|
||||
if not access_token:
|
||||
return self._web_client
|
||||
# WebClient keeps its own HTTP session and rate-limit state, so
|
||||
# reuse one per connection until its token changes.
|
||||
cached = self._connection_web_clients.get(msg.connection_id)
|
||||
if cached is not None and cached[0] == access_token:
|
||||
return cached[1]
|
||||
if self._web_client_factory is None:
|
||||
from slack_sdk import WebClient
|
||||
|
||||
self._web_client_factory = WebClient
|
||||
web_client = self._web_client_factory(token=access_token)
|
||||
self._connection_web_clients[msg.connection_id] = (access_token, web_client)
|
||||
return web_client
|
||||
return self._web_client
|
||||
|
||||
@staticmethod
|
||||
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
try:
|
||||
web_client.reactions_add(
|
||||
channel=channel_id,
|
||||
timestamp=timestamp,
|
||||
name=emoji,
|
||||
@@ -213,6 +256,12 @@ class SlackChannel(Channel):
|
||||
if "already_reacted" not in str(exc):
|
||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
return
|
||||
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
||||
|
||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||
if not self._web_client:
|
||||
@@ -249,12 +298,15 @@ class SlackChannel(Channel):
|
||||
|
||||
# Handle message events (DM or @mention)
|
||||
if etype in ("message", "app_mention"):
|
||||
self._handle_message_event(event)
|
||||
self._handle_message_event(
|
||||
event,
|
||||
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing Slack event")
|
||||
|
||||
def _handle_message_event(self, event: dict) -> None:
|
||||
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
|
||||
# Ignore bot messages
|
||||
if event.get("bot_id") or event.get("subtype"):
|
||||
return
|
||||
@@ -272,6 +324,19 @@ class SlackChannel(Channel):
|
||||
if not text:
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code:
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
event=event,
|
||||
team_id=str(team_id or event.get("team") or ""),
|
||||
code=connect_code,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
return
|
||||
|
||||
channel_id = event.get("channel", "")
|
||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||
|
||||
@@ -297,4 +362,61 @@ class SlackChannel(Channel):
|
||||
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
||||
# Send "running" reply first (fire-and-forget from SDK thread)
|
||||
self._send_running_reply(channel_id, thread_ts)
|
||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||
if self._connection_repo is None:
|
||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||
else:
|
||||
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
|
||||
|
||||
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
|
||||
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="slack",
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
channel_id = str(event.get("channel") or "")
|
||||
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
||||
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
|
||||
if state is None:
|
||||
await self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
|
||||
return True
|
||||
|
||||
user_id = str(event.get("user") or "")
|
||||
if not user_id or not team_id:
|
||||
await self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="slack",
|
||||
external_account_id=user_id,
|
||||
workspace_id=team_id,
|
||||
metadata={
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
|
||||
return True
|
||||
|
||||
async def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
|
||||
if not self._web_client or not channel_id:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
|
||||
if thread_ts:
|
||||
kwargs["thread_ts"] = thread_ts
|
||||
try:
|
||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||
except Exception:
|
||||
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
|
||||
|
||||
@@ -8,6 +8,7 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,6 +36,7 @@ class TelegramChannel(Channel):
|
||||
pass
|
||||
# chat_id -> last sent message_id for threaded replies
|
||||
self._last_bot_message: dict[str, int] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -233,6 +235,54 @@ class TelegramChannel(Channel):
|
||||
return True
|
||||
return user_id in self._allowed_users
|
||||
|
||||
@staticmethod
|
||||
def _telegram_display_name(user) -> str:
|
||||
full_name = getattr(user, "full_name", None)
|
||||
if isinstance(full_name, str) and full_name:
|
||||
return full_name
|
||||
username = getattr(user, "username", None)
|
||||
if isinstance(username, str) and username:
|
||||
return username
|
||||
return str(getattr(user, "id", ""))
|
||||
|
||||
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
||||
if self._connection_repo is None or not state_token:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
||||
if state is None:
|
||||
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
||||
return True
|
||||
|
||||
owner_user_id = state["owner_user_id"]
|
||||
user_id = str(update.effective_user.id)
|
||||
chat_id = str(update.effective_chat.id)
|
||||
connection = await self._connection_repo.upsert_connection(
|
||||
owner_user_id=owner_user_id,
|
||||
provider="telegram",
|
||||
external_account_id=user_id,
|
||||
external_account_name=self._telegram_display_name(update.effective_user),
|
||||
workspace_id=chat_id,
|
||||
workspace_name=None,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"chat_type": update.effective_chat.type,
|
||||
"telegram_username": getattr(update.effective_user, "username", None),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
||||
await update.message.reply_text("Telegram connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="telegram",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
def _get_bot_username(self, context) -> str | None:
|
||||
bot = getattr(context, "bot", None)
|
||||
username = getattr(bot, "username", None)
|
||||
@@ -264,6 +314,11 @@ class TelegramChannel(Channel):
|
||||
"""Handle /start command."""
|
||||
if not self._check_user(update.effective_user.id):
|
||||
return
|
||||
args = getattr(context, "args", []) if context is not None else []
|
||||
if args:
|
||||
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||
if handled:
|
||||
return
|
||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||
|
||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||
@@ -299,6 +354,7 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
@@ -341,6 +397,7 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
|
||||
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
|
||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._load_state()
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
|
||||
if thread_ts:
|
||||
self._context_tokens_by_thread[thread_ts] = context_token
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound = self._make_inbound(
|
||||
chat_id=chat_id,
|
||||
user_id=chat_id,
|
||||
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = None
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wechat",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not chat_id:
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wechat",
|
||||
external_account_id=chat_id,
|
||||
workspace_id=chat_id,
|
||||
metadata={
|
||||
"context_token": context_token,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
|
||||
if not context_token:
|
||||
return
|
||||
await self._send_text_message(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
text=text,
|
||||
client_id_prefix="deerflow-connect",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
async def _ensure_authenticated(self) -> bool:
|
||||
async with self._auth_lock:
|
||||
if self._bot_token:
|
||||
|
||||
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
InboundMessage,
|
||||
InboundMessageType,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
|
||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||
self._ws_stream_ids: dict[str, str] = {}
|
||||
self._working_message = "Working on it..."
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
|
||||
|
||||
user_id = (body.get("from") or {}).get("userid")
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id=str(user_id or ""),
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=user_id, # keep user's conversation in memory
|
||||
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wecom",
|
||||
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id:
|
||||
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
body = frame.get("body", {}) or {}
|
||||
workspace_id = str(body.get("aibotid") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wecom",
|
||||
external_account_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
metadata={
|
||||
"aibotid": workspace_id,
|
||||
"chattype": body.get("chattype"),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
||||
|
||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.gateway.routers import (
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channel_connections,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
@@ -384,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||
app.include_router(suggestions.router)
|
||||
|
||||
# User-facing IM channel connection API is mounted at /api/channels
|
||||
app.include_router(channel_connections.router)
|
||||
|
||||
# Channels API is mounted at /api/channels
|
||||
app.include_router(channels.router)
|
||||
|
||||
|
||||
@@ -6,9 +6,11 @@ import logging
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
|
||||
AUTH_DISABLED_USER_ID = "e2e-user"
|
||||
AUTH_DISABLED_USER_EMAIL = "e2e@test.local"
|
||||
AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
|
||||
AUTH_DISABLED_USER_EMAIL = "default@test.local"
|
||||
|
||||
AUTH_SOURCE_SESSION = "session"
|
||||
AUTH_SOURCE_INTERNAL = "internal"
|
||||
|
||||
@@ -276,6 +276,8 @@ def require_permission(
|
||||
# strict-deny rather than strict-allow — only an *existing*
|
||||
# row with a *different* user_id triggers 404.
|
||||
if owner_check:
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
@@ -288,6 +290,22 @@ def require_permission(
|
||||
str(auth.user.id),
|
||||
require_existing=require_existing,
|
||||
)
|
||||
if not allowed and getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||
# Trusted internal callers (channel workers) also act for
|
||||
# the connection owner carried in X-DeerFlow-Owner-User-Id.
|
||||
# Scope the check to that owner instead of bypassing it; a
|
||||
# leaked internal token must not grant cross-user thread
|
||||
# access. The header is honored only after ``auth`` proved
|
||||
# the caller holds the internal token (mirrors
|
||||
# get_trusted_internal_owner_user_id, which keys off the
|
||||
# middleware-stamped ``request.state.user``).
|
||||
header_owner = (request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) or "").strip()
|
||||
if header_owner:
|
||||
allowed = await thread_store.check_access(
|
||||
thread_id,
|
||||
header_owner,
|
||||
require_existing=require_existing,
|
||||
)
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
||||
@@ -5,10 +5,12 @@ from __future__ import annotations
|
||||
import os
|
||||
import secrets
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||
INTERNAL_SYSTEM_ROLE = "internal"
|
||||
|
||||
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
|
||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||
|
||||
|
||||
def create_internal_auth_headers() -> dict[str, str]:
|
||||
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
if owner_user_id:
|
||||
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
||||
return headers
|
||||
|
||||
|
||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
def get_internal_user():
|
||||
"""Return the synthetic user used for trusted internal channel calls."""
|
||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||
|
||||
|
||||
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
|
||||
"""Return the owner override for a trusted internal request, if present.
|
||||
|
||||
The header is ignored for normal browser/API callers. It is only honored
|
||||
after ``AuthMiddleware`` has validated the internal auth token and stamped
|
||||
the synthetic internal user onto ``request.state.user``.
|
||||
"""
|
||||
user = getattr(getattr(request, "state", None), "user", None)
|
||||
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||
return None
|
||||
|
||||
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
|
||||
if not owner_user_id:
|
||||
return None
|
||||
owner_user_id = owner_user_id.strip()
|
||||
return owner_user_id or None
|
||||
|
||||
@@ -0,0 +1,670 @@
|
||||
"""Browser-facing APIs for user-owned IM channel bindings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.channels.runtime_config_store import (
|
||||
ChannelRuntimeConfigStore,
|
||||
apply_runtime_connection_config,
|
||||
merge_runtime_channel_configs,
|
||||
)
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STATE_TTL_SECONDS = 600
|
||||
_MASKED_CREDENTIAL_VALUE = "********"
|
||||
|
||||
|
||||
class ChannelCredentialFieldResponse(BaseModel):
|
||||
name: str
|
||||
label: str
|
||||
type: str = "text"
|
||||
required: bool = True
|
||||
|
||||
|
||||
class ChannelProviderResponse(BaseModel):
|
||||
provider: str
|
||||
display_name: str
|
||||
enabled: bool
|
||||
configured: bool
|
||||
connectable: bool
|
||||
unavailable_reason: str | None = None
|
||||
auth_mode: str
|
||||
connection_status: str
|
||||
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
|
||||
credential_values: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelProvidersResponse(BaseModel):
|
||||
enabled: bool
|
||||
providers: list[ChannelProviderResponse]
|
||||
|
||||
|
||||
class ChannelConnectionResponse(BaseModel):
|
||||
id: str
|
||||
provider: str
|
||||
status: str
|
||||
external_account_id: str | None = None
|
||||
external_account_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelConnectionsResponse(BaseModel):
|
||||
connections: list[ChannelConnectionResponse]
|
||||
|
||||
|
||||
class ChannelConnectResponse(BaseModel):
|
||||
provider: str
|
||||
mode: str
|
||||
url: str | None = None
|
||||
code: str
|
||||
instruction: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
class ChannelRuntimeConfigRequest(BaseModel):
|
||||
values: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
_PROVIDER_META: dict[str, dict[str, str]] = {
|
||||
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
|
||||
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
|
||||
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
|
||||
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
|
||||
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
|
||||
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
|
||||
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
|
||||
}
|
||||
|
||||
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
|
||||
"telegram": (
|
||||
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||
{"name": "bot_username", "label": "Bot username", "type": "text"},
|
||||
),
|
||||
"slack": (
|
||||
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||
{"name": "app_token", "label": "App token", "type": "password"},
|
||||
),
|
||||
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||
"feishu": (
|
||||
{"name": "app_id", "label": "App ID", "type": "text"},
|
||||
{"name": "app_secret", "label": "App secret", "type": "password"},
|
||||
),
|
||||
"dingtalk": (
|
||||
{"name": "client_id", "label": "Client ID", "type": "text"},
|
||||
{"name": "client_secret", "label": "Client secret", "type": "password"},
|
||||
),
|
||||
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||
"wecom": (
|
||||
{"name": "bot_id", "label": "Bot ID", "type": "text"},
|
||||
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
|
||||
),
|
||||
}
|
||||
|
||||
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
|
||||
"telegram": ("bot_token",),
|
||||
"slack": ("bot_token", "app_token"),
|
||||
"discord": ("bot_token",),
|
||||
"feishu": ("app_id", "app_secret"),
|
||||
"dingtalk": ("client_id", "client_secret"),
|
||||
"wechat": ("bot_token",),
|
||||
"wecom": ("bot_id", "bot_secret"),
|
||||
}
|
||||
|
||||
|
||||
def _get_user_id(request: Request) -> str:
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return str(user.id)
|
||||
|
||||
|
||||
async def _require_admin_user(request: Request) -> None:
|
||||
"""Require an admin caller for instance-wide channel runtime mutations.
|
||||
|
||||
Runtime credentials and the channel workers they start/stop are shared by
|
||||
every user of the deployment, so only admins may change them (same model
|
||||
as the MCP config API). Auth-disabled local mode uses a synthetic admin
|
||||
user and is unaffected.
|
||||
"""
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is None:
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if getattr(user, "system_role", None) != "admin":
|
||||
raise HTTPException(status_code=403, detail="Admin privileges required to manage channel runtime credentials.")
|
||||
|
||||
|
||||
def _get_app_config():
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
return get_app_config()
|
||||
|
||||
|
||||
async def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
|
||||
store = getattr(request.app.state, "channel_runtime_config_store", None)
|
||||
if isinstance(store, ChannelRuntimeConfigStore):
|
||||
return store
|
||||
# Constructing the store reads its JSON file from disk; keep it off the
|
||||
# event loop.
|
||||
store = await asyncio.to_thread(ChannelRuntimeConfigStore)
|
||||
request.app.state.channel_runtime_config_store = store
|
||||
return store
|
||||
|
||||
|
||||
async def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||
config = getattr(request.app.state, "channel_connections_config", None)
|
||||
if not isinstance(config, ChannelConnectionsConfig):
|
||||
config = _get_app_config().channel_connections
|
||||
config = apply_runtime_connection_config(config, store=await _get_runtime_config_store(request))
|
||||
request.app.state.channel_connections_config = config
|
||||
return config
|
||||
|
||||
|
||||
async def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||
state_config = getattr(request.app.state, "channels_config", None)
|
||||
if isinstance(state_config, dict):
|
||||
return state_config
|
||||
|
||||
result = await _load_channels_config(request, await _get_channel_connections_config(request))
|
||||
request.app.state.channels_config = result
|
||||
return result
|
||||
|
||||
|
||||
async def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
|
||||
app_config = _get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = extra.get("channels")
|
||||
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||
merge_runtime_channel_configs(
|
||||
result,
|
||||
config,
|
||||
store=await _get_runtime_config_store(request),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||
repo = getattr(request.app.state, "channel_connection_repo", None)
|
||||
if isinstance(repo, ChannelConnectionRepository):
|
||||
return repo
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||
|
||||
repo = ChannelConnectionRepository(sf)
|
||||
request.app.state.channel_connection_repo = repo
|
||||
return repo
|
||||
|
||||
|
||||
def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||
provider_config = getattr(config, provider, None)
|
||||
if provider_config is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return provider_config
|
||||
|
||||
|
||||
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||
return False
|
||||
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
|
||||
|
||||
|
||||
def _runtime_unavailable_reason(provider: str) -> str:
|
||||
meta = _PROVIDER_META.get(provider)
|
||||
display_name = meta["display_name"] if meta else provider
|
||||
return f"Enter the required {display_name} credentials to connect this channel."
|
||||
|
||||
|
||||
def _runtime_not_running_reason(provider: str) -> str:
|
||||
meta = _PROVIDER_META.get(provider)
|
||||
display_name = meta["display_name"] if meta else provider
|
||||
return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
|
||||
|
||||
|
||||
def _runtime_channel_running(provider: str) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.debug("Unable to inspect channel service status", exc_info=True)
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
try:
|
||||
status = service.get_status()
|
||||
except Exception:
|
||||
logger.debug("Unable to read channel service status", exc_info=True)
|
||||
return None
|
||||
|
||||
if not status.get("service_running"):
|
||||
return False
|
||||
channel_status = status.get("channels", {}).get(provider)
|
||||
if not isinstance(channel_status, dict):
|
||||
return None
|
||||
return bool(channel_status.get("running"))
|
||||
|
||||
|
||||
async def _ensure_runtime_channel_ready_if_available(
|
||||
provider: str,
|
||||
channels_config: dict[str, Any],
|
||||
) -> bool | None:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
|
||||
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
|
||||
if ensure_channel_ready is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await ensure_channel_ready(provider, runtime_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to reconcile runtime channel readiness")
|
||||
return False
|
||||
|
||||
|
||||
def _provider_unavailable_reason(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
) -> str | None:
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
return None
|
||||
if not provider_config.configured:
|
||||
return _runtime_unavailable_reason(provider)
|
||||
if not _runtime_channel_configured(provider, channels_config):
|
||||
return _runtime_unavailable_reason(provider)
|
||||
if _runtime_channel_running(provider) is False:
|
||||
return _runtime_not_running_reason(provider)
|
||||
return None
|
||||
|
||||
|
||||
def _provider_status(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
) -> tuple[dict[str, bool], str | None]:
|
||||
declared = config.provider_status(provider)
|
||||
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
|
||||
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
|
||||
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
|
||||
|
||||
|
||||
def _new_binding_code() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
async def _create_state(
|
||||
repo: ChannelConnectionRepository,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
) -> str:
|
||||
state = _new_binding_code()
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _connect_instruction(provider: str, code: str) -> str:
|
||||
if provider == "telegram":
|
||||
return f"Send /start {code} to the DeerFlow Telegram bot."
|
||||
meta = _PROVIDER_META.get(provider)
|
||||
if meta is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
|
||||
|
||||
|
||||
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
|
||||
if provider == "telegram":
|
||||
provider_config = _provider_config(config, provider)
|
||||
return f"https://t.me/{provider_config.bot_username}?start={code}"
|
||||
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
|
||||
return None
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
|
||||
|
||||
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
||||
value = connection.get("updated_at")
|
||||
if isinstance(value, datetime):
|
||||
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
||||
if isinstance(value, str) and value:
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
pass
|
||||
return datetime.min.replace(tzinfo=UTC)
|
||||
|
||||
|
||||
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
by_provider: dict[str, dict[str, Any]] = {}
|
||||
for item in connections:
|
||||
existing = by_provider.get(item["provider"])
|
||||
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
|
||||
by_provider[item["provider"]] = item
|
||||
return by_provider
|
||||
|
||||
|
||||
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
|
||||
fields = _CREDENTIAL_FIELDS.get(provider)
|
||||
if fields is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return [ChannelCredentialFieldResponse(**field) for field in fields]
|
||||
|
||||
|
||||
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict):
|
||||
return {}
|
||||
|
||||
values: dict[str, str] = {}
|
||||
for field in _credential_fields(provider):
|
||||
value = str(runtime_config.get(field.name) or "").strip()
|
||||
if not value:
|
||||
continue
|
||||
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
|
||||
return values
|
||||
|
||||
|
||||
def _provider_response(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
meta: dict[str, str],
|
||||
connection: dict[str, Any] | None = None,
|
||||
) -> ChannelProviderResponse:
|
||||
from app.gateway.auth_disabled import is_auth_disabled
|
||||
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
if connection:
|
||||
connection_status = connection["status"]
|
||||
elif is_auth_disabled() and status["configured"] and unavailable_reason is None:
|
||||
# Auth-disabled local mode routes every channel message to the default
|
||||
# user, so a configured running channel needs no per-user binding.
|
||||
connection_status = "connected"
|
||||
else:
|
||||
connection_status = "not_connected"
|
||||
credential_values = _credential_values(provider, channels_config)
|
||||
if provider == "telegram" and not credential_values.get("bot_username"):
|
||||
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
|
||||
if bot_username:
|
||||
credential_values["bot_username"] = bot_username
|
||||
return ChannelProviderResponse(
|
||||
provider=provider,
|
||||
display_name=meta["display_name"],
|
||||
enabled=status["enabled"],
|
||||
configured=status["configured"],
|
||||
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
|
||||
unavailable_reason=unavailable_reason,
|
||||
auth_mode=meta["auth_mode"],
|
||||
connection_status=connection_status,
|
||||
credential_fields=_credential_fields(provider),
|
||||
credential_values=credential_values,
|
||||
)
|
||||
|
||||
|
||||
def _required_runtime_values(
|
||||
provider: str,
|
||||
values: dict[str, str],
|
||||
existing_config: dict[str, Any] | None = None,
|
||||
) -> dict[str, str]:
|
||||
fields = _credential_fields(provider)
|
||||
cleaned: dict[str, str] = {}
|
||||
missing: list[str] = []
|
||||
existing_config = existing_config or {}
|
||||
for field in fields:
|
||||
raw_value = values.get(field.name, "")
|
||||
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
|
||||
existing_value = str(existing_config.get(field.name) or "").strip()
|
||||
if existing_value:
|
||||
cleaned[field.name] = existing_value
|
||||
continue
|
||||
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
|
||||
if field.required and not value:
|
||||
missing.append(field.label)
|
||||
cleaned[field.name] = value
|
||||
if missing:
|
||||
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
|
||||
return cleaned
|
||||
|
||||
|
||||
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel service while configuring a runtime channel")
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
return await service.configure_channel(provider, runtime_config)
|
||||
|
||||
|
||||
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel service while disconnecting a runtime channel")
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
|
||||
runtime_config = channels_config.get(provider)
|
||||
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
|
||||
return await service.configure_channel(provider, runtime_config)
|
||||
return await service.remove_channel(provider)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = await _get_channel_connections_config(request)
|
||||
channels_config = await _get_channels_config(request)
|
||||
repo = None
|
||||
if config.enabled:
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
owner_user_id = _get_user_id(request)
|
||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||
by_provider = _newest_connection_by_provider(connections)
|
||||
|
||||
enabled_providers = [provider for provider in _PROVIDER_META if config.provider_status(provider)["enabled"]]
|
||||
# Readiness reconciliation is independent per provider; run it
|
||||
# concurrently so one slow channel restart does not serialize the
|
||||
# whole /providers response.
|
||||
await asyncio.gather(
|
||||
*(_ensure_runtime_channel_ready_if_available(provider, channels_config) for provider in enabled_providers if _runtime_channel_configured(provider, channels_config)),
|
||||
)
|
||||
|
||||
providers: list[ChannelProviderResponse] = []
|
||||
for provider in enabled_providers:
|
||||
connection = by_provider.get(provider)
|
||||
providers.append(_provider_response(config, channels_config, provider, _PROVIDER_META[provider], connection))
|
||||
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||
|
||||
|
||||
@router.get("/connections", response_model=ChannelConnectionsResponse)
|
||||
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
return ChannelConnectionsResponse(connections=[])
|
||||
repo = _get_repository(request, config)
|
||||
rows = await repo.list_connections(_get_user_id(request))
|
||||
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
|
||||
|
||||
|
||||
@router.delete("/connections/{connection_id}", status_code=204)
|
||||
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection_id,
|
||||
owner_user_id=_get_user_id(request),
|
||||
)
|
||||
if not disconnected:
|
||||
raise HTTPException(status_code=404, detail="Channel connection not found")
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
|
||||
await _require_admin_user(request)
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
|
||||
owner_user_id = _get_user_id(request)
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
repo = None
|
||||
|
||||
if repo is not None:
|
||||
for connection in await repo.list_connections(owner_user_id):
|
||||
if connection["provider"] == provider and connection["status"] != "revoked":
|
||||
await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id=owner_user_id,
|
||||
)
|
||||
|
||||
store = await _get_runtime_config_store(request)
|
||||
await asyncio.to_thread(store.set_provider_disconnected, provider)
|
||||
channels_config = await _load_channels_config(request, config)
|
||||
request.app.state.channels_config = channels_config
|
||||
|
||||
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
|
||||
if stopped is False:
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
|
||||
|
||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||
|
||||
|
||||
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||
config = await _get_channel_connections_config(request)
|
||||
channels_config = await _get_channels_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
|
||||
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
if not status["enabled"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
if unavailable_reason:
|
||||
raise HTTPException(status_code=400, detail=unavailable_reason)
|
||||
if not status["configured"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
code = await _create_state(
|
||||
repo,
|
||||
owner_user_id=_get_user_id(request),
|
||||
provider=provider,
|
||||
)
|
||||
return ChannelConnectResponse(
|
||||
provider=provider,
|
||||
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||
url=_connect_url(config, provider, code),
|
||||
code=code,
|
||||
instruction=_connect_instruction(provider, code),
|
||||
expires_in=_STATE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||
async def configure_channel_provider_runtime(
|
||||
provider: str,
|
||||
body: ChannelRuntimeConfigRequest,
|
||||
request: Request,
|
||||
) -> ChannelProviderResponse:
|
||||
await _require_admin_user(request)
|
||||
config = await _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
|
||||
channels_config = await _get_channels_config(request)
|
||||
existing = channels_config.get(provider)
|
||||
runtime_config = dict(existing) if isinstance(existing, dict) else {}
|
||||
values = _required_runtime_values(provider, body.values, runtime_config)
|
||||
runtime_config["enabled"] = True
|
||||
|
||||
for key in _RUNTIME_REQUIREMENTS[provider]:
|
||||
runtime_config[key] = values[key]
|
||||
|
||||
if provider == "telegram":
|
||||
# The deep-link username is persisted with the runtime channel config
|
||||
# (set_provider_config below) and applied to future requests via
|
||||
# apply_runtime_connection_config; never mutate the config instance
|
||||
# cached by get_app_config().
|
||||
runtime_config["bot_username"] = values["bot_username"]
|
||||
|
||||
channels_config[provider] = runtime_config
|
||||
request.app.state.channels_config = channels_config
|
||||
|
||||
started = await _restart_runtime_channel_if_available(provider, runtime_config)
|
||||
if started is False:
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
||||
|
||||
store = await _get_runtime_config_store(request)
|
||||
await asyncio.to_thread(store.set_provider_config, provider, runtime_config)
|
||||
|
||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer
|
||||
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = now_iso()
|
||||
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_store.get(thread_id)
|
||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||
if existing_record is None and thread_owner_user_id:
|
||||
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
||||
if unscoped_record is not None:
|
||||
if unscoped_record.get("user_id") != thread_owner_user_id:
|
||||
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
await thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
**thread_owner_kwargs,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_to_messages
|
||||
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime import (
|
||||
@@ -35,6 +36,7 @@ from deerflow.runtime import (
|
||||
run_agent,
|
||||
)
|
||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -315,6 +317,7 @@ async def start_run(
|
||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||
)
|
||||
|
||||
owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||
# Stateless run endpoints carry thread_id in the request *body*, so the
|
||||
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
||||
# from the path param -- cannot protect them. Enforce thread ownership here,
|
||||
@@ -323,79 +326,99 @@ async def start_run(
|
||||
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
||||
# via check_access; only a thread already owned by another user is rejected
|
||||
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
||||
# channel runs act on behalf of IM users they do not own (see
|
||||
# inject_authenticated_user_context), so the internal system role is exempt.
|
||||
# channel runs act on behalf of the connection owner carried in
|
||||
# X-DeerFlow-Owner-User-Id, so they are scoped to that owner instead of
|
||||
# bypassing the check -- a leaked internal token must not grant cross-user
|
||||
# thread access.
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
|
||||
if user is not None:
|
||||
allowed = await run_ctx.thread_store.check_access(thread_id, str(user.id))
|
||||
if not allowed and owner_user_id and getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||
# Channel workers may also act for the connection owner named in
|
||||
# the trusted header (e.g. claiming a legacy default-owned channel
|
||||
# thread for its real owner).
|
||||
allowed = await run_ctx.thread_store.check_access(thread_id, owner_user_id)
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
model_name=model_name,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
model_name=model_name,
|
||||
user_id=owner_user_id,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None and owner_user_id:
|
||||
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
||||
if unscoped_existing is not None:
|
||||
if unscoped_existing.get("user_id") != owner_user_id:
|
||||
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
)
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
return record
|
||||
finally:
|
||||
if owner_context_token is not None:
|
||||
reset_current_user(owner_context_token)
|
||||
|
||||
|
||||
async def sse_consumer(
|
||||
|
||||
Reference in New Issue
Block a user