mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
Add user-owned IM channel connections
This commit is contained in:
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -69,6 +69,7 @@ class DiscordChannel(Channel):
|
||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._discord_module = None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -314,6 +315,7 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
self._publish(inbound)
|
||||
# Start typing indicator in the thread
|
||||
if typing_target:
|
||||
@@ -421,6 +423,7 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
|
||||
# Start typing indicator in the correct target (thread or channel)
|
||||
if typing_target:
|
||||
@@ -435,6 +438,31 @@ class DiscordChannel(Channel):
|
||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
connection = None
|
||||
if guild_id:
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="discord",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=guild_id,
|
||||
)
|
||||
if connection is None:
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="discord",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=None,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
def _run_client(self) -> None:
|
||||
self._discord_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._discord_loop)
|
||||
|
||||
@@ -614,6 +614,7 @@ class ChannelManager:
|
||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||
default_session: dict[str, Any] | None = None,
|
||||
channel_sessions: dict[str, Any] | None = None,
|
||||
connection_repo: Any | None = None,
|
||||
) -> None:
|
||||
self.bus = bus
|
||||
self.store = store
|
||||
@@ -623,6 +624,7 @@ class ChannelManager:
|
||||
self._assistant_id = assistant_id
|
||||
self._default_session = _as_dict(default_session)
|
||||
self._channel_sessions = dict(channel_sessions or {})
|
||||
self._connection_repo = connection_repo
|
||||
self._client = None # lazy init — langgraph_sdk async client
|
||||
self._csrf_token = generate_csrf_token()
|
||||
self._semaphore: asyncio.Semaphore | None = None
|
||||
@@ -671,12 +673,16 @@ class ChannelManager:
|
||||
configurable["checkpoint_ns"] = ""
|
||||
configurable["thread_id"] = thread_id
|
||||
|
||||
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
||||
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
||||
# under ``channel_user_id`` for platform-facing lookups.
|
||||
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
||||
# For browser-connected IM channels, prefer the DeerFlow account that
|
||||
# owns the connection. Preserve the raw platform user under
|
||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||
if msg.user_id:
|
||||
if msg.owner_user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.owner_user_id)
|
||||
elif msg.user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||
if msg.user_id:
|
||||
run_context_identity["channel_user_id"] = msg.user_id
|
||||
|
||||
run_context = _merge_dicts(
|
||||
@@ -792,10 +798,27 @@ class ChannelManager:
|
||||
|
||||
# -- chat handling -----------------------------------------------------
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
thread = await client.threads.create()
|
||||
thread_id = thread["thread_id"]
|
||||
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
return await self._connection_repo.get_thread_id(
|
||||
msg.connection_id,
|
||||
msg.chat_id,
|
||||
msg.topic_id,
|
||||
)
|
||||
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
|
||||
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
||||
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
||||
await self._connection_repo.set_thread_id(
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
provider=msg.channel_name,
|
||||
external_conversation_id=msg.chat_id,
|
||||
external_topic_id=msg.topic_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return
|
||||
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
@@ -803,6 +826,12 @@ class ChannelManager:
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
thread = await client.threads.create()
|
||||
thread_id = thread["thread_id"]
|
||||
await self._store_thread_id(msg, thread_id)
|
||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||
return thread_id
|
||||
|
||||
@@ -812,7 +841,7 @@ class ChannelManager:
|
||||
# Look up existing DeerFlow thread.
|
||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
if thread_id:
|
||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||
|
||||
@@ -896,6 +925,8 @@ class ChannelManager:
|
||||
artifacts=artifacts,
|
||||
attachments=attachments,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||
@@ -958,6 +989,8 @@ class ChannelManager:
|
||||
text=latest_text,
|
||||
is_final=False,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata),
|
||||
)
|
||||
)
|
||||
@@ -1004,6 +1037,8 @@ class ChannelManager:
|
||||
attachments=attachments,
|
||||
is_final=True,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
)
|
||||
@@ -1028,16 +1063,10 @@ class ChannelManager:
|
||||
client = self._get_client()
|
||||
thread = await client.threads.create()
|
||||
new_thread_id = thread["thread_id"]
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
new_thread_id,
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
await self._store_thread_id(msg, new_thread_id)
|
||||
reply = "New conversation started."
|
||||
elif command == "status":
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||
elif command == "models":
|
||||
reply = await self._fetch_gateway("/api/models", "models")
|
||||
@@ -1060,9 +1089,11 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
text=reply,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
@@ -1098,9 +1129,11 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
text=error_text,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
|
||||
@@ -44,6 +44,12 @@ class InboundMessage:
|
||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||
reuse the same DeerFlow thread. When ``None``, each message
|
||||
creates a new thread (one-shot Q&A).
|
||||
connection_id: Optional DeerFlow channel connection id. When present,
|
||||
conversation mapping is scoped by the connection instead of the
|
||||
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
Platform user ids stay in ``user_id``.
|
||||
workspace_id: Optional external workspace/guild/team id.
|
||||
files: Optional list of file attachments (platform-specific dicts).
|
||||
metadata: Arbitrary extra data from the channel.
|
||||
created_at: Unix timestamp when the message was created.
|
||||
@@ -56,6 +62,9 @@ class InboundMessage:
|
||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||
thread_ts: str | None = None
|
||||
topic_id: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
workspace_id: str | None = None
|
||||
files: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
@@ -95,6 +104,9 @@ class OutboundMessage:
|
||||
is_final: Whether this is the final message in the response stream.
|
||||
thread_ts: Optional platform thread identifier for threaded replies.
|
||||
metadata: Arbitrary extra data.
|
||||
connection_id: Optional DeerFlow channel connection id used for
|
||||
connection-specific outbound credentials.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
created_at: Unix timestamp.
|
||||
"""
|
||||
|
||||
@@ -106,6 +118,8 @@ class OutboundMessage:
|
||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||
is_final: bool = True
|
||||
thread_ts: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""Provider-specific helpers for user-owned IM channel connections."""
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Discord OAuth helpers for user-owned channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
DISCORD_API_BASE_URL = "https://discord.com/api/v10"
|
||||
DISCORD_TOKEN_URL = f"{DISCORD_API_BASE_URL}/oauth2/token"
|
||||
DISCORD_CURRENT_USER_URL = f"{DISCORD_API_BASE_URL}/users/@me"
|
||||
DISCORD_CURRENT_USER_GUILDS_URL = f"{DISCORD_API_BASE_URL}/users/@me/guilds"
|
||||
|
||||
|
||||
class DiscordConnectError(RuntimeError):
|
||||
"""Raised when Discord OAuth fails."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DiscordIdentity:
|
||||
user_id: str
|
||||
display_name: str | None
|
||||
username: str | None
|
||||
guilds: list[dict[str, Any]]
|
||||
access_token: str
|
||||
refresh_token: str | None
|
||||
token_type: str | None
|
||||
scopes: list[str]
|
||||
expires_at: datetime | None
|
||||
raw_token: dict[str, Any]
|
||||
|
||||
|
||||
def _split_scopes(value: str | None) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
return [scope.strip() for scope in value.replace(",", " ").split() if scope.strip()]
|
||||
|
||||
|
||||
def _display_name(user: dict[str, Any]) -> str | None:
|
||||
global_name = user.get("global_name")
|
||||
if isinstance(global_name, str) and global_name:
|
||||
return global_name
|
||||
username = user.get("username")
|
||||
return str(username) if username else None
|
||||
|
||||
|
||||
async def complete_discord_oauth(
|
||||
*,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> DiscordIdentity:
|
||||
async def _complete(client: httpx.AsyncClient) -> DiscordIdentity:
|
||||
token_response = await client.post(
|
||||
DISCORD_TOKEN_URL,
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
timeout=10,
|
||||
)
|
||||
token_response.raise_for_status()
|
||||
token = token_response.json()
|
||||
access_token = token.get("access_token")
|
||||
if not access_token:
|
||||
raise DiscordConnectError("Discord OAuth response did not include an access token")
|
||||
|
||||
auth_headers = {"Authorization": f"Bearer {access_token}"}
|
||||
user_response = await client.get(DISCORD_CURRENT_USER_URL, headers=auth_headers, timeout=10)
|
||||
user_response.raise_for_status()
|
||||
user = user_response.json()
|
||||
user_id = user.get("id")
|
||||
if not user_id:
|
||||
raise DiscordConnectError("Discord user response did not include a user id")
|
||||
|
||||
guilds_response = await client.get(DISCORD_CURRENT_USER_GUILDS_URL, headers=auth_headers, timeout=10)
|
||||
guilds: list[dict[str, Any]] = []
|
||||
if guilds_response.status_code == 200:
|
||||
guilds = guilds_response.json()
|
||||
|
||||
expires_at = None
|
||||
expires_in = token.get("expires_in")
|
||||
if isinstance(expires_in, int | float):
|
||||
expires_at = datetime.now(UTC) + timedelta(seconds=float(expires_in))
|
||||
|
||||
return DiscordIdentity(
|
||||
user_id=str(user_id),
|
||||
display_name=_display_name(user),
|
||||
username=user.get("username"),
|
||||
guilds=guilds,
|
||||
access_token=str(access_token),
|
||||
refresh_token=token.get("refresh_token"),
|
||||
token_type=token.get("token_type"),
|
||||
scopes=_split_scopes(token.get("scope")),
|
||||
expires_at=expires_at,
|
||||
raw_token=token,
|
||||
)
|
||||
|
||||
if http_client is None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
return await _complete(client)
|
||||
return await _complete(http_client)
|
||||
@@ -0,0 +1,110 @@
|
||||
"""Slack OAuth and Events helpers for user-owned channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
SLACK_OAUTH_ACCESS_URL = "https://slack.com/api/oauth.v2.access"
|
||||
SLACK_SIGNATURE_VERSION = "v0"
|
||||
SLACK_SIGNATURE_TOLERANCE_SECONDS = 60 * 5
|
||||
|
||||
|
||||
class SlackConnectError(RuntimeError):
|
||||
"""Raised when Slack OAuth or request verification fails."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SlackInstall:
|
||||
team_id: str
|
||||
team_name: str | None
|
||||
authed_user_id: str
|
||||
bot_user_id: str | None
|
||||
bot_access_token: str
|
||||
scopes: list[str]
|
||||
raw: dict[str, Any]
|
||||
|
||||
|
||||
def verify_slack_signature(
|
||||
*,
|
||||
signing_secret: str,
|
||||
timestamp: str | None,
|
||||
body: bytes,
|
||||
signature: str | None,
|
||||
now: int | None = None,
|
||||
) -> bool:
|
||||
if not signing_secret or not timestamp or not signature:
|
||||
return False
|
||||
|
||||
try:
|
||||
timestamp_int = int(timestamp)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
current_time = int(time.time()) if now is None else now
|
||||
if abs(current_time - timestamp_int) > SLACK_SIGNATURE_TOLERANCE_SECONDS:
|
||||
return False
|
||||
|
||||
base = f"{SLACK_SIGNATURE_VERSION}:{timestamp}:".encode() + body
|
||||
digest = hmac.new(signing_secret.encode("utf-8"), base, hashlib.sha256).hexdigest()
|
||||
expected = f"{SLACK_SIGNATURE_VERSION}={digest}"
|
||||
return hmac.compare_digest(expected, signature)
|
||||
|
||||
|
||||
def _split_scopes(value: str | None) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
return [scope.strip() for scope in value.split(",") if scope.strip()]
|
||||
|
||||
|
||||
async def exchange_slack_oauth_code(
|
||||
*,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> SlackInstall:
|
||||
async def _post(client: httpx.AsyncClient) -> dict[str, Any]:
|
||||
response = await client.post(
|
||||
SLACK_OAUTH_ACCESS_URL,
|
||||
data={
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
if http_client is None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
payload = await _post(client)
|
||||
else:
|
||||
payload = await _post(http_client)
|
||||
|
||||
if not payload.get("ok"):
|
||||
raise SlackConnectError(str(payload.get("error") or "Slack OAuth exchange failed"))
|
||||
|
||||
access_token = payload.get("access_token")
|
||||
team = payload.get("team") or {}
|
||||
authed_user = payload.get("authed_user") or {}
|
||||
if not access_token or not team.get("id") or not authed_user.get("id"):
|
||||
raise SlackConnectError("Slack OAuth response did not include required installation fields")
|
||||
|
||||
return SlackInstall(
|
||||
team_id=str(team["id"]),
|
||||
team_name=team.get("name"),
|
||||
authed_user_id=str(authed_user["id"]),
|
||||
bot_user_id=payload.get("bot_user_id"),
|
||||
bot_access_token=str(access_token),
|
||||
scopes=_split_scopes(payload.get("scope")),
|
||||
raw=payload,
|
||||
)
|
||||
@@ -52,6 +52,56 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
||||
return default
|
||||
|
||||
|
||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return
|
||||
|
||||
telegram = getattr(connection_config, "telegram", None)
|
||||
if telegram is not None and getattr(telegram, "enabled", False) and getattr(telegram, "configured", False):
|
||||
telegram_config = dict(channels_config.get("telegram", {})) if isinstance(channels_config.get("telegram"), dict) else {}
|
||||
telegram_config.setdefault("enabled", True)
|
||||
telegram_config.setdefault("bot_token", telegram.bot_token)
|
||||
channels_config["telegram"] = telegram_config
|
||||
|
||||
slack = getattr(connection_config, "slack", None)
|
||||
if slack is not None and getattr(slack, "enabled", False) and getattr(slack, "configured", False):
|
||||
slack_config = dict(channels_config.get("slack", {})) if isinstance(channels_config.get("slack"), dict) else {}
|
||||
slack_config.setdefault("enabled", True)
|
||||
slack_config.setdefault("event_delivery", slack.event_delivery)
|
||||
slack_config.setdefault("signing_secret", slack.signing_secret)
|
||||
channels_config["slack"] = slack_config
|
||||
|
||||
discord = getattr(connection_config, "discord", None)
|
||||
if discord is not None and getattr(discord, "enabled", False) and getattr(discord, "configured", False):
|
||||
discord_config = dict(channels_config.get("discord", {})) if isinstance(channels_config.get("discord"), dict) else {}
|
||||
discord_config.setdefault("enabled", True)
|
||||
discord_config.setdefault("bot_token", discord.bot_token)
|
||||
channels_config["discord"] = discord_config
|
||||
|
||||
|
||||
def _make_connection_repo(app_config: AppConfig):
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return None
|
||||
encryption_key = getattr(connection_config, "encryption_key", "")
|
||||
if not encryption_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel connection repository")
|
||||
return None
|
||||
|
||||
session_factory = get_session_factory()
|
||||
if session_factory is None:
|
||||
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||
return None
|
||||
return ChannelConnectionRepository(session_factory, cipher=ChannelCredentialCipher.from_key(encryption_key))
|
||||
|
||||
|
||||
class ChannelService:
|
||||
"""Manages the lifecycle of all configured IM channels.
|
||||
|
||||
@@ -59,9 +109,10 @@ class ChannelService:
|
||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||
"""
|
||||
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||
self.bus = MessageBus()
|
||||
self.store = ChannelStore()
|
||||
self._connection_repo = connection_repo
|
||||
config = dict(channels_config or {})
|
||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||
@@ -74,6 +125,7 @@ class ChannelService:
|
||||
gateway_url=gateway_url,
|
||||
default_session=default_session if isinstance(default_session, dict) else None,
|
||||
channel_sessions=channel_sessions,
|
||||
connection_repo=connection_repo,
|
||||
)
|
||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||
self._config = config
|
||||
@@ -90,8 +142,9 @@ class ChannelService:
|
||||
# extra fields are allowed by AppConfig (extra="allow")
|
||||
extra = app_config.model_extra or {}
|
||||
if "channels" in extra:
|
||||
channels_config = extra["channels"]
|
||||
return cls(channels_config=channels_config)
|
||||
channels_config = dict(extra["channels"] or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the manager and all enabled channels."""
|
||||
@@ -169,6 +222,8 @@ class ChannelService:
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
if self._connection_repo is not None:
|
||||
config["connection_repo"] = self._connection_repo
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
|
||||
@@ -49,6 +49,8 @@ class SlackChannel(Channel):
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._web_client_factory = config.get("web_client_factory")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -63,15 +65,24 @@ class SlackChannel(Channel):
|
||||
return
|
||||
|
||||
self._SocketModeResponse = SocketModeResponse
|
||||
if self._web_client_factory is None:
|
||||
self._web_client_factory = WebClient
|
||||
|
||||
bot_token = self.config.get("bot_token", "")
|
||||
app_token = self.config.get("app_token", "")
|
||||
|
||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("Slack channel started in HTTP Events mode")
|
||||
return
|
||||
|
||||
if not bot_token or not app_token:
|
||||
logger.error("Slack channel requires bot_token and app_token")
|
||||
return
|
||||
|
||||
self._web_client = WebClient(token=bot_token)
|
||||
self._web_client = self._web_client_factory(token=bot_token)
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=app_token,
|
||||
web_client=self._web_client,
|
||||
@@ -96,7 +107,8 @@ class SlackChannel(Channel):
|
||||
logger.info("Slack channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._web_client:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
return
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
@@ -109,11 +121,12 @@ class SlackChannel(Channel):
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||
# Add a completion reaction to the thread root
|
||||
if msg.thread_ts:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction,
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"white_check_mark",
|
||||
@@ -137,7 +150,8 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction,
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"x",
|
||||
@@ -149,7 +163,8 @@ class SlackChannel(Channel):
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._web_client:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -162,7 +177,7 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
kwargs["thread_ts"] = msg.thread_ts
|
||||
|
||||
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
||||
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||
return True
|
||||
except Exception:
|
||||
@@ -171,12 +186,24 @@ class SlackChannel(Channel):
|
||||
|
||||
# -- internal ----------------------------------------------------------
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
return
|
||||
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||
access_token = credentials.get("access_token") if credentials else None
|
||||
if not access_token:
|
||||
logger.warning("[Slack] no bot token found for connection=%s", msg.connection_id)
|
||||
return None
|
||||
if self._web_client_factory is None:
|
||||
from slack_sdk import WebClient
|
||||
|
||||
self._web_client_factory = WebClient
|
||||
return self._web_client_factory(token=access_token)
|
||||
return self._web_client
|
||||
|
||||
@staticmethod
|
||||
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
try:
|
||||
self._web_client.reactions_add(
|
||||
web_client.reactions_add(
|
||||
channel=channel_id,
|
||||
timestamp=timestamp,
|
||||
name=emoji,
|
||||
@@ -185,6 +212,12 @@ class SlackChannel(Channel):
|
||||
if "already_reacted" not in str(exc):
|
||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
return
|
||||
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
||||
|
||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||
if not self._web_client:
|
||||
|
||||
@@ -35,6 +35,7 @@ class TelegramChannel(Channel):
|
||||
pass
|
||||
# chat_id -> last sent message_id for threaded replies
|
||||
self._last_bot_message: dict[str, int] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -171,6 +172,26 @@ class TelegramChannel(Channel):
|
||||
logger.exception("[Telegram] failed to send file: %s", attachment.filename)
|
||||
return False
|
||||
|
||||
async def process_webhook_update(self, payload: dict[str, Any]) -> bool:
|
||||
if not self._application:
|
||||
return False
|
||||
try:
|
||||
from telegram import Update
|
||||
except ImportError:
|
||||
logger.error("python-telegram-bot is not installed. Install it with: uv add python-telegram-bot")
|
||||
return False
|
||||
|
||||
update = Update.de_json(payload, self._application.bot)
|
||||
if update is None:
|
||||
return False
|
||||
|
||||
if self._tg_loop and self._tg_loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(self._application.process_update(update), self._tg_loop)
|
||||
await asyncio.wrap_future(future)
|
||||
else:
|
||||
await self._application.process_update(update)
|
||||
return True
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||
@@ -228,10 +249,72 @@ class TelegramChannel(Channel):
|
||||
return True
|
||||
return user_id in self._allowed_users
|
||||
|
||||
@staticmethod
|
||||
def _telegram_display_name(user) -> str:
|
||||
full_name = getattr(user, "full_name", None)
|
||||
if isinstance(full_name, str) and full_name:
|
||||
return full_name
|
||||
username = getattr(user, "username", None)
|
||||
if isinstance(username, str) and username:
|
||||
return username
|
||||
return str(getattr(user, "id", ""))
|
||||
|
||||
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
||||
if self._connection_repo is None or not state_token:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
||||
if state is None:
|
||||
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
||||
return True
|
||||
|
||||
owner_user_id = state["owner_user_id"]
|
||||
user_id = str(update.effective_user.id)
|
||||
chat_id = str(update.effective_chat.id)
|
||||
connection = await self._connection_repo.upsert_connection(
|
||||
owner_user_id=owner_user_id,
|
||||
provider="telegram",
|
||||
external_account_id=user_id,
|
||||
external_account_name=self._telegram_display_name(update.effective_user),
|
||||
workspace_id=chat_id,
|
||||
workspace_name=None,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"chat_type": update.effective_chat.type,
|
||||
"telegram_username": getattr(update.effective_user, "username", None),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
||||
await update.message.reply_text("Telegram connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="telegram",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
async def _cmd_start(self, update, context) -> None:
|
||||
"""Handle /start command."""
|
||||
if not self._check_user(update.effective_user.id):
|
||||
return
|
||||
args = getattr(context, "args", []) if context is not None else []
|
||||
if args:
|
||||
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||
if handled:
|
||||
return
|
||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||
|
||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||
@@ -267,6 +350,7 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
@@ -309,6 +393,7 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.gateway.routers import (
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channel_connections,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
@@ -376,6 +377,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||
app.include_router(suggestions.router)
|
||||
|
||||
# User-facing IM channel connection API is mounted at /api/channels
|
||||
app.include_router(channel_connections.router)
|
||||
|
||||
# Channels API is mounted at /api/channels
|
||||
app.include_router(channels.router)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ _PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/channels/webhooks/",
|
||||
)
|
||||
|
||||
# Exact auth paths that are public (login/register/status check).
|
||||
@@ -38,6 +39,8 @@ _PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
"/api/v1/auth/initialize",
|
||||
"/api/channels/slack/callback",
|
||||
"/api/channels/discord/callback",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ def should_check_csrf(request: Request) -> bool:
|
||||
return False
|
||||
|
||||
path = request.url.path.rstrip("/")
|
||||
if path.startswith("/api/channels/webhooks/"):
|
||||
return False
|
||||
# Exempt /api/v1/auth/me endpoint
|
||||
if path == "/api/v1/auth/me":
|
||||
return False
|
||||
|
||||
@@ -0,0 +1,487 @@
|
||||
"""Browser-facing APIs for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import PlainTextResponse, RedirectResponse
|
||||
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType
|
||||
from app.channels.providers import discord_connect, slack_connect
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||
|
||||
_STATE_TTL_SECONDS = 600
|
||||
|
||||
|
||||
class ChannelProviderResponse(BaseModel):
|
||||
provider: str
|
||||
display_name: str
|
||||
enabled: bool
|
||||
configured: bool
|
||||
auth_mode: str
|
||||
connection_status: str
|
||||
|
||||
|
||||
class ChannelProvidersResponse(BaseModel):
|
||||
enabled: bool
|
||||
providers: list[ChannelProviderResponse]
|
||||
|
||||
|
||||
class ChannelConnectionResponse(BaseModel):
|
||||
id: str
|
||||
provider: str
|
||||
status: str
|
||||
external_account_id: str | None = None
|
||||
external_account_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelConnectionsResponse(BaseModel):
|
||||
connections: list[ChannelConnectionResponse]
|
||||
|
||||
|
||||
class ChannelConnectResponse(BaseModel):
|
||||
provider: str
|
||||
mode: str
|
||||
url: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
_PROVIDER_META: dict[str, dict[str, str]] = {
|
||||
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
|
||||
"slack": {"display_name": "Slack", "auth_mode": "oauth"},
|
||||
"discord": {"display_name": "Discord", "auth_mode": "oauth_and_bot_install"},
|
||||
}
|
||||
|
||||
|
||||
def _get_user_id(request: Request) -> str:
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return str(user.id)
|
||||
|
||||
|
||||
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||
config = getattr(request.app.state, "channel_connections_config", None)
|
||||
if isinstance(config, ChannelConnectionsConfig):
|
||||
return config
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
return get_app_config().channel_connections
|
||||
|
||||
|
||||
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||
repo = getattr(request.app.state, "channel_connection_repo", None)
|
||||
if isinstance(repo, ChannelConnectionRepository):
|
||||
return repo
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||
if not config.encryption_key:
|
||||
raise HTTPException(status_code=503, detail="Channel connection encryption key is not configured")
|
||||
|
||||
repo = ChannelConnectionRepository(sf, cipher=ChannelCredentialCipher.from_key(config.encryption_key))
|
||||
request.app.state.channel_connection_repo = repo
|
||||
return repo
|
||||
|
||||
|
||||
def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||
provider_config = getattr(config, provider, None)
|
||||
if provider_config is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return provider_config
|
||||
|
||||
|
||||
async def _create_state(
|
||||
repo: ChannelConnectionRepository,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
requested_scopes: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
state = secrets.token_urlsafe(32)
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
requested_scopes=requested_scopes,
|
||||
metadata=metadata,
|
||||
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _build_connect_url(config: ChannelConnectionsConfig, provider: str, state: str) -> str:
|
||||
provider_config = _provider_config(config, provider)
|
||||
if provider == "telegram":
|
||||
return f"https://t.me/{provider_config.bot_username}?start={state}"
|
||||
|
||||
redirect_uri = f"{config.public_base_url.rstrip('/')}/api/channels/{provider}/callback"
|
||||
if provider == "slack":
|
||||
query = urlencode(
|
||||
{
|
||||
"client_id": provider_config.client_id,
|
||||
"scope": ",".join(provider_config.scopes),
|
||||
"redirect_uri": redirect_uri,
|
||||
"state": state,
|
||||
}
|
||||
)
|
||||
return f"https://slack.com/oauth/v2/authorize?{query}"
|
||||
|
||||
if provider == "discord":
|
||||
scopes = "identify guilds bot applications.commands"
|
||||
query = urlencode(
|
||||
{
|
||||
"client_id": provider_config.client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": scopes,
|
||||
"state": state,
|
||||
"permissions": provider_config.permissions,
|
||||
}
|
||||
)
|
||||
return f"https://discord.com/oauth2/authorize?{query}"
|
||||
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
|
||||
|
||||
def _callback_redirect(provider: str, state_data: dict[str, Any]) -> RedirectResponse:
|
||||
redirect_after = state_data.get("redirect_after")
|
||||
if isinstance(redirect_after, str) and redirect_after:
|
||||
return RedirectResponse(redirect_after)
|
||||
return RedirectResponse(f"/workspace?channel_connected={provider}")
|
||||
|
||||
|
||||
def _get_message_bus(request: Request):
|
||||
bus = getattr(request.app.state, "channel_message_bus", None)
|
||||
if bus is not None:
|
||||
return bus
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
return None
|
||||
service = get_channel_service()
|
||||
return service.bus if service is not None else None
|
||||
|
||||
|
||||
def _get_channel_instance(request: Request, name: str):
|
||||
channel_instances = getattr(request.app.state, "channel_instances", None)
|
||||
if isinstance(channel_instances, dict) and name in channel_instances:
|
||||
return channel_instances[name]
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
return None
|
||||
service = get_channel_service()
|
||||
return service.get_channel(name) if service is not None else None
|
||||
|
||||
|
||||
async def _publish_slack_event(
|
||||
*,
|
||||
repo: ChannelConnectionRepository,
|
||||
bus: Any,
|
||||
payload: dict[str, Any],
|
||||
) -> bool:
|
||||
event = payload.get("event") or {}
|
||||
event_type = event.get("type")
|
||||
if event_type not in {"message", "app_mention"}:
|
||||
return False
|
||||
if event.get("bot_id") or event.get("subtype"):
|
||||
return False
|
||||
|
||||
text = str(event.get("text") or "").strip()
|
||||
user_id = str(event.get("user") or "")
|
||||
channel_id = str(event.get("channel") or "")
|
||||
team_id = str(payload.get("team_id") or event.get("team") or event.get("team_id") or "")
|
||||
if not text or not user_id or not channel_id or not team_id:
|
||||
return False
|
||||
|
||||
connection = await repo.find_connection_by_external_identity(
|
||||
provider="slack",
|
||||
external_account_id=user_id,
|
||||
workspace_id=team_id,
|
||||
)
|
||||
if connection is None:
|
||||
return False
|
||||
|
||||
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
||||
inbound = InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id=channel_id,
|
||||
user_id=user_id,
|
||||
text=text,
|
||||
msg_type=InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT,
|
||||
thread_ts=thread_ts,
|
||||
metadata={"team_id": team_id, "event_id": payload.get("event_id")},
|
||||
connection_id=connection["id"],
|
||||
owner_user_id=connection["owner_user_id"],
|
||||
workspace_id=team_id,
|
||||
)
|
||||
inbound.topic_id = thread_ts or None
|
||||
await bus.publish_inbound(inbound)
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
repo = _get_repository(request, config) if config.enabled and config.encryption_key else None
|
||||
owner_user_id = _get_user_id(request)
|
||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||
by_provider = {item["provider"]: item for item in connections}
|
||||
|
||||
providers: list[ChannelProviderResponse] = []
|
||||
for provider, meta in _PROVIDER_META.items():
|
||||
status = config.provider_status(provider)
|
||||
connection = by_provider.get(provider)
|
||||
providers.append(
|
||||
ChannelProviderResponse(
|
||||
provider=provider,
|
||||
display_name=meta["display_name"],
|
||||
enabled=status["enabled"],
|
||||
configured=status["configured"],
|
||||
auth_mode=meta["auth_mode"],
|
||||
connection_status=connection["status"] if connection else "not_connected",
|
||||
)
|
||||
)
|
||||
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||
|
||||
|
||||
@router.get("/connections", response_model=ChannelConnectionsResponse)
|
||||
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
return ChannelConnectionsResponse(connections=[])
|
||||
repo = _get_repository(request, config)
|
||||
rows = await repo.list_connections(_get_user_id(request))
|
||||
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
|
||||
|
||||
|
||||
@router.delete("/connections/{connection_id}", status_code=204)
|
||||
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
|
||||
config = _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection_id,
|
||||
owner_user_id=_get_user_id(request),
|
||||
)
|
||||
if not disconnected:
|
||||
raise HTTPException(status_code=404, detail="Channel connection not found")
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.get("/slack/callback")
|
||||
async def slack_oauth_callback(request: Request, code: str | None = None, state: str | None = None, error: str | None = None):
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=f"Slack OAuth failed: {error}")
|
||||
if not code or not state:
|
||||
raise HTTPException(status_code=400, detail="Slack OAuth callback is missing code or state")
|
||||
|
||||
config = _get_channel_connections_config(request)
|
||||
provider_config = _provider_config(config, "slack")
|
||||
if not config.enabled or not provider_config.enabled or not provider_config.configured:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
state_data = await repo.consume_oauth_state(provider="slack", state=state)
|
||||
if state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
|
||||
|
||||
redirect_uri = f"{config.public_base_url.rstrip('/')}/api/channels/slack/callback"
|
||||
install = await slack_connect.exchange_slack_oauth_code(
|
||||
client_id=provider_config.client_id,
|
||||
client_secret=provider_config.client_secret,
|
||||
code=code,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id=state_data["owner_user_id"],
|
||||
provider="slack",
|
||||
external_account_id=install.authed_user_id,
|
||||
workspace_id=install.team_id,
|
||||
workspace_name=install.team_name,
|
||||
bot_user_id=install.bot_user_id,
|
||||
scopes=install.scopes or state_data.get("requested_scopes", []),
|
||||
metadata={"team_id": install.team_id, "team_name": install.team_name},
|
||||
status="connected",
|
||||
)
|
||||
await repo.store_credentials(
|
||||
connection["id"],
|
||||
access_token=install.bot_access_token,
|
||||
token_type="Bearer",
|
||||
extra={"bot_user_id": install.bot_user_id, "team_id": install.team_id},
|
||||
)
|
||||
return _callback_redirect("slack", state_data)
|
||||
|
||||
|
||||
@router.get("/discord/callback")
|
||||
async def discord_oauth_callback(request: Request, code: str | None = None, state: str | None = None, error: str | None = None):
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=f"Discord OAuth failed: {error}")
|
||||
if not code or not state:
|
||||
raise HTTPException(status_code=400, detail="Discord OAuth callback is missing code or state")
|
||||
|
||||
config = _get_channel_connections_config(request)
|
||||
provider_config = _provider_config(config, "discord")
|
||||
if not config.enabled or not provider_config.enabled or not provider_config.configured:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
state_data = await repo.consume_oauth_state(provider="discord", state=state)
|
||||
if state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
|
||||
|
||||
redirect_uri = f"{config.public_base_url.rstrip('/')}/api/channels/discord/callback"
|
||||
identity = await discord_connect.complete_discord_oauth(
|
||||
client_id=provider_config.client_id,
|
||||
client_secret=provider_config.client_secret,
|
||||
code=code,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id=state_data["owner_user_id"],
|
||||
provider="discord",
|
||||
external_account_id=identity.user_id,
|
||||
external_account_name=identity.display_name or identity.username,
|
||||
scopes=identity.scopes or state_data.get("requested_scopes", []),
|
||||
capabilities={"message_content_intent_required": provider_config.require_message_content_intent},
|
||||
metadata={"username": identity.username, "guilds": identity.guilds},
|
||||
status="connected",
|
||||
)
|
||||
await repo.store_credentials(
|
||||
connection["id"],
|
||||
access_token=identity.access_token,
|
||||
refresh_token=identity.refresh_token,
|
||||
token_type=identity.token_type,
|
||||
expires_at=identity.expires_at,
|
||||
extra={"guilds": identity.guilds},
|
||||
)
|
||||
return _callback_redirect("discord", state_data)
|
||||
|
||||
|
||||
@router.post("/webhooks/slack/events")
|
||||
async def slack_events_webhook(request: Request):
|
||||
config = _get_channel_connections_config(request)
|
||||
provider_config = _provider_config(config, "slack")
|
||||
if not config.enabled or not provider_config.enabled or not provider_config.configured:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
body = await request.body()
|
||||
if not slack_connect.verify_slack_signature(
|
||||
signing_secret=provider_config.signing_secret,
|
||||
timestamp=request.headers.get("X-Slack-Request-Timestamp"),
|
||||
body=body,
|
||||
signature=request.headers.get("X-Slack-Signature"),
|
||||
):
|
||||
raise HTTPException(status_code=401, detail="Invalid Slack signature")
|
||||
|
||||
try:
|
||||
payload = json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid Slack payload") from exc
|
||||
|
||||
if payload.get("type") == "url_verification":
|
||||
challenge = payload.get("challenge")
|
||||
if not isinstance(challenge, str):
|
||||
raise HTTPException(status_code=400, detail="Slack challenge is missing")
|
||||
return PlainTextResponse(challenge)
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
delivery_id = str(payload.get("event_id") or hashlib.sha256(body).hexdigest())
|
||||
payload_hash = hashlib.sha256(body).hexdigest()
|
||||
event = payload.get("event") or {}
|
||||
is_new = await repo.record_webhook_delivery(
|
||||
provider="slack",
|
||||
delivery_id=delivery_id,
|
||||
payload_sha256=payload_hash,
|
||||
event_type=event.get("type"),
|
||||
)
|
||||
if not is_new:
|
||||
return {"ok": True, "duplicate": True, "processed": False}
|
||||
|
||||
bus = _get_message_bus(request)
|
||||
processed = False
|
||||
if bus is not None:
|
||||
processed = await _publish_slack_event(repo=repo, bus=bus, payload=payload)
|
||||
return {"ok": True, "processed": processed}
|
||||
|
||||
|
||||
@router.post("/webhooks/telegram")
|
||||
async def telegram_webhook(request: Request):
|
||||
config = _get_channel_connections_config(request)
|
||||
provider_config = _provider_config(config, "telegram")
|
||||
if not config.enabled or not provider_config.enabled or not provider_config.configured:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
secret_header = request.headers.get("X-Telegram-Bot-Api-Secret-Token")
|
||||
if not secret_header or not secrets.compare_digest(secret_header, provider_config.webhook_secret):
|
||||
raise HTTPException(status_code=401, detail="Invalid Telegram webhook secret")
|
||||
|
||||
body = await request.body()
|
||||
try:
|
||||
payload = json.loads(body.decode("utf-8"))
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid Telegram payload") from exc
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
delivery_id = str(payload.get("update_id") or hashlib.sha256(body).hexdigest())
|
||||
is_new = await repo.record_webhook_delivery(
|
||||
provider="telegram",
|
||||
delivery_id=delivery_id,
|
||||
payload_sha256=hashlib.sha256(body).hexdigest(),
|
||||
event_type="update",
|
||||
)
|
||||
if not is_new:
|
||||
return {"ok": True, "duplicate": True, "processed": False}
|
||||
|
||||
processed = False
|
||||
channel = _get_channel_instance(request, "telegram")
|
||||
process_update = getattr(channel, "process_webhook_update", None)
|
||||
if process_update is not None:
|
||||
processed = bool(await process_update(payload))
|
||||
return {"ok": True, "processed": processed}
|
||||
|
||||
|
||||
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled or not provider_config.configured:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
state = await _create_state(
|
||||
repo,
|
||||
owner_user_id=_get_user_id(request),
|
||||
provider=provider,
|
||||
requested_scopes=getattr(provider_config, "scopes", []),
|
||||
)
|
||||
return ChannelConnectResponse(
|
||||
provider=provider,
|
||||
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||
url=_build_connect_url(config, provider, state),
|
||||
expires_in=_STATE_TTL_SECONDS,
|
||||
)
|
||||
Reference in New Issue
Block a user