Align IM connections with local channels

This commit is contained in:
taohe
2026-06-10 22:16:47 +08:00
parent 92c185b90d
commit d06643d8a2
33 changed files with 588 additions and 1536 deletions
+59
View File
@@ -18,6 +18,16 @@ logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000
def _extract_connect_code(text: str) -> str | None:
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
class DiscordChannel(Channel):
"""Discord bot channel.
@@ -288,6 +298,10 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = _extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
# --- Determine thread/channel routing and typing target ---
thread_id = None
chat_id = None
@@ -464,6 +478,51 @@ class DiscordChannel(Channel):
inbound.workspace_id = connection.get("workspace_id")
return inbound
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code:
return False
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
if state is None:
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
return True
guild = getattr(message, "guild", None)
channel = getattr(message, "channel", None)
author = getattr(message, "author", None)
user_id = str(getattr(author, "id", "") or "")
if not user_id:
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
return True
guild_id = str(getattr(guild, "id", "") or "") or None
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="discord",
external_account_id=user_id,
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
workspace_id=guild_id,
workspace_name=getattr(guild, "name", None) if guild is not None else None,
metadata={
"guild_id": guild_id,
"channel_id": str(getattr(channel, "id", "") or ""),
},
status="connected",
)
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
return True
@staticmethod
async def _send_connection_reply(message, text: str) -> None:
channel = getattr(message, "channel", None)
send = getattr(channel, "send", None)
if send is None:
return
try:
await send(text)
except Exception:
logger.exception("[Discord] failed to send connection reply")
def _run_client(self) -> None:
self._discord_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._discord_loop)
@@ -1 +0,0 @@
"""Provider-specific helpers for user-owned IM channel connections."""
@@ -1,110 +0,0 @@
"""Discord OAuth helpers for user-owned channel connections."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
import httpx
DISCORD_API_BASE_URL = "https://discord.com/api/v10"
DISCORD_TOKEN_URL = f"{DISCORD_API_BASE_URL}/oauth2/token"
DISCORD_CURRENT_USER_URL = f"{DISCORD_API_BASE_URL}/users/@me"
DISCORD_CURRENT_USER_GUILDS_URL = f"{DISCORD_API_BASE_URL}/users/@me/guilds"
class DiscordConnectError(RuntimeError):
"""Raised when Discord OAuth fails."""
@dataclass(frozen=True)
class DiscordIdentity:
user_id: str
display_name: str | None
username: str | None
guilds: list[dict[str, Any]]
access_token: str
refresh_token: str | None
token_type: str | None
scopes: list[str]
expires_at: datetime | None
raw_token: dict[str, Any]
def _split_scopes(value: str | None) -> list[str]:
if not value:
return []
return [scope.strip() for scope in value.replace(",", " ").split() if scope.strip()]
def _display_name(user: dict[str, Any]) -> str | None:
global_name = user.get("global_name")
if isinstance(global_name, str) and global_name:
return global_name
username = user.get("username")
return str(username) if username else None
async def complete_discord_oauth(
*,
client_id: str,
client_secret: str,
code: str,
redirect_uri: str,
http_client: httpx.AsyncClient | None = None,
) -> DiscordIdentity:
async def _complete(client: httpx.AsyncClient) -> DiscordIdentity:
token_response = await client.post(
DISCORD_TOKEN_URL,
data={
"client_id": client_id,
"client_secret": client_secret,
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
timeout=10,
)
token_response.raise_for_status()
token = token_response.json()
access_token = token.get("access_token")
if not access_token:
raise DiscordConnectError("Discord OAuth response did not include an access token")
auth_headers = {"Authorization": f"Bearer {access_token}"}
user_response = await client.get(DISCORD_CURRENT_USER_URL, headers=auth_headers, timeout=10)
user_response.raise_for_status()
user = user_response.json()
user_id = user.get("id")
if not user_id:
raise DiscordConnectError("Discord user response did not include a user id")
guilds_response = await client.get(DISCORD_CURRENT_USER_GUILDS_URL, headers=auth_headers, timeout=10)
guilds: list[dict[str, Any]] = []
if guilds_response.status_code == 200:
guilds = guilds_response.json()
expires_at = None
expires_in = token.get("expires_in")
if isinstance(expires_in, int | float):
expires_at = datetime.now(UTC) + timedelta(seconds=float(expires_in))
return DiscordIdentity(
user_id=str(user_id),
display_name=_display_name(user),
username=user.get("username"),
guilds=guilds,
access_token=str(access_token),
refresh_token=token.get("refresh_token"),
token_type=token.get("token_type"),
scopes=_split_scopes(token.get("scope")),
expires_at=expires_at,
raw_token=token,
)
if http_client is None:
async with httpx.AsyncClient() as client:
return await _complete(client)
return await _complete(http_client)
@@ -1,110 +0,0 @@
"""Slack OAuth and Events helpers for user-owned channel connections."""
from __future__ import annotations
import hashlib
import hmac
import time
from dataclasses import dataclass
from typing import Any
import httpx
SLACK_OAUTH_ACCESS_URL = "https://slack.com/api/oauth.v2.access"
SLACK_SIGNATURE_VERSION = "v0"
SLACK_SIGNATURE_TOLERANCE_SECONDS = 60 * 5
class SlackConnectError(RuntimeError):
"""Raised when Slack OAuth or request verification fails."""
@dataclass(frozen=True)
class SlackInstall:
team_id: str
team_name: str | None
authed_user_id: str
bot_user_id: str | None
bot_access_token: str
scopes: list[str]
raw: dict[str, Any]
def verify_slack_signature(
*,
signing_secret: str,
timestamp: str | None,
body: bytes,
signature: str | None,
now: int | None = None,
) -> bool:
if not signing_secret or not timestamp or not signature:
return False
try:
timestamp_int = int(timestamp)
except (TypeError, ValueError):
return False
current_time = int(time.time()) if now is None else now
if abs(current_time - timestamp_int) > SLACK_SIGNATURE_TOLERANCE_SECONDS:
return False
base = f"{SLACK_SIGNATURE_VERSION}:{timestamp}:".encode() + body
digest = hmac.new(signing_secret.encode("utf-8"), base, hashlib.sha256).hexdigest()
expected = f"{SLACK_SIGNATURE_VERSION}={digest}"
return hmac.compare_digest(expected, signature)
def _split_scopes(value: str | None) -> list[str]:
if not value:
return []
return [scope.strip() for scope in value.split(",") if scope.strip()]
async def exchange_slack_oauth_code(
*,
client_id: str,
client_secret: str,
code: str,
redirect_uri: str,
http_client: httpx.AsyncClient | None = None,
) -> SlackInstall:
async def _post(client: httpx.AsyncClient) -> dict[str, Any]:
response = await client.post(
SLACK_OAUTH_ACCESS_URL,
data={
"client_id": client_id,
"client_secret": client_secret,
"code": code,
"redirect_uri": redirect_uri,
},
timeout=10,
)
response.raise_for_status()
return response.json()
if http_client is None:
async with httpx.AsyncClient() as client:
payload = await _post(client)
else:
payload = await _post(http_client)
if not payload.get("ok"):
raise SlackConnectError(str(payload.get("error") or "Slack OAuth exchange failed"))
access_token = payload.get("access_token")
team = payload.get("team") or {}
authed_user = payload.get("authed_user") or {}
if not access_token or not team.get("id") or not authed_user.get("id"):
raise SlackConnectError("Slack OAuth response did not include required installation fields")
return SlackInstall(
team_id=str(team["id"]),
team_name=team.get("name"),
authed_user_id=str(authed_user["id"]),
bot_user_id=payload.get("bot_user_id"),
bot_access_token=str(access_token),
scopes=_split_scopes(payload.get("scope")),
raw=payload,
)
+2 -26
View File
@@ -57,37 +57,14 @@ def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], ap
if connection_config is None or not getattr(connection_config, "enabled", False):
return
telegram = getattr(connection_config, "telegram", None)
if telegram is not None and getattr(telegram, "enabled", False) and getattr(telegram, "configured", False):
telegram_config = dict(channels_config.get("telegram", {})) if isinstance(channels_config.get("telegram"), dict) else {}
telegram_config.setdefault("enabled", True)
telegram_config.setdefault("bot_token", telegram.bot_token)
channels_config["telegram"] = telegram_config
slack = getattr(connection_config, "slack", None)
if slack is not None and getattr(slack, "enabled", False) and getattr(slack, "configured", False):
slack_config = dict(channels_config.get("slack", {})) if isinstance(channels_config.get("slack"), dict) else {}
slack_config.setdefault("enabled", True)
slack_config.setdefault("event_delivery", slack.event_delivery)
slack_config.setdefault("signing_secret", slack.signing_secret)
channels_config["slack"] = slack_config
discord = getattr(connection_config, "discord", None)
if discord is not None and getattr(discord, "enabled", False) and getattr(discord, "configured", False):
discord_config = dict(channels_config.get("discord", {})) if isinstance(channels_config.get("discord"), dict) else {}
discord_config.setdefault("enabled", True)
discord_config.setdefault("bot_token", discord.bot_token)
channels_config["discord"] = discord_config
def _make_connection_repo(app_config: AppConfig):
connection_config = getattr(app_config, "channel_connections", None)
if connection_config is None or not getattr(connection_config, "enabled", False):
return None
encryption_key = getattr(connection_config, "encryption_key", "")
try:
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
except Exception:
logger.exception("Failed to import channel connection repository")
@@ -97,8 +74,7 @@ def _make_connection_repo(app_config: AppConfig):
if session_factory is None:
logger.warning("Channel connections are enabled but database persistence is not available")
return None
cipher = ChannelCredentialCipher.from_key(encryption_key) if encryption_key else None
return ChannelConnectionRepository(session_factory, cipher=cipher)
return ChannelConnectionRepository(session_factory)
class ChannelService:
+99 -5
View File
@@ -47,6 +47,16 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
return text[end + 1 :].lstrip()
def _extract_connect_code(text: str) -> str | None:
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
@@ -219,8 +229,7 @@ class SlackChannel(Channel):
credentials = await self._connection_repo.get_credentials(msg.connection_id)
access_token = credentials.get("access_token") if credentials else None
if not access_token:
logger.warning("[Slack] no bot token found for connection=%s", msg.connection_id)
return None
return self._web_client
if self._web_client_factory is None:
from slack_sdk import WebClient
@@ -282,12 +291,15 @@ class SlackChannel(Channel):
# Handle message events (DM or @mention)
if etype in ("message", "app_mention"):
self._handle_message_event(event)
self._handle_message_event(
event,
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
)
except Exception:
logger.exception("Error processing Slack event")
def _handle_message_event(self, event: dict) -> None:
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
# Ignore bot messages
if event.get("bot_id") or event.get("subtype"):
return
@@ -305,6 +317,19 @@ class SlackChannel(Channel):
if not text:
return
connect_code = _extract_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
event=event,
team_id=str(team_id or event.get("team") or ""),
code=connect_code,
),
self._loop,
)
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
@@ -330,4 +355,73 @@ class SlackChannel(Channel):
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
# Send "running" reply first (fire-and-forget from SDK thread)
self._send_running_reply(channel_id, thread_ts)
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
if self._connection_repo is None:
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
else:
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
if self._connection_repo is None:
return inbound
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
if not workspace_id:
return inbound
connection = await self._connection_repo.find_connection_by_external_identity(
provider="slack",
external_account_id=inbound.user_id,
workspace_id=workspace_id,
)
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
return False
channel_id = str(event.get("channel") or "")
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
if state is None:
self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
return True
user_id = str(event.get("user") or "")
if not user_id or not team_id:
self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
return True
await self._connection_repo.upsert_connection(
owner_user_id=state["owner_user_id"],
provider="slack",
external_account_id=user_id,
workspace_id=team_id,
metadata={
"team_id": team_id,
"channel_id": channel_id,
},
status="connected",
)
self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
return True
def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
if not self._web_client or not channel_id:
return
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
if thread_ts:
kwargs["thread_ts"] = thread_ts
try:
self._web_client.chat_postMessage(**kwargs)
except Exception:
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)