From b7097baaec8793eb54ab7cfaefe4f1f4fdf2ffed Mon Sep 17 00:00:00 2001
From: taohe
Date: Thu, 11 Jun 2026 10:33:44 +0800
Subject: [PATCH] Address IM channel review comments

---
 backend/app/channels/commands.py              | 11 +++++
 backend/app/channels/connection_identity.py   | 49 +++++++++++++++++++
 backend/app/channels/discord.py               | 45 +++++--------------
 backend/app/channels/slack.py                 | 33 +++-----------
 backend/app/channels/telegram.py              | 16 ++-----
 .../gateway/routers/channel_connections.py    | 29 ++++++++++--
 .../persistence/channel_connections/sql.py    |  5 ++-
 .../test_channel_connections_repository.py    | 24 ++++++++++
 8 files changed, 134 insertions(+), 78 deletions(-)
 create mode 100644 backend/app/channels/connection_identity.py

diff --git a/backend/app/channels/commands.py b/backend/app/channels/commands.py
index c783899c5..86e4e9105 100644
--- a/backend/app/channels/commands.py
+++ b/backend/app/channels/commands.py
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
 )
 
 
+def extract_connect_code(text: str) -> str | None:
+    """Extract the one-time channel binding code from a connect command."""
+    parts = text.strip().split()
+    if len(parts) < 2:
+        return None
+    command = parts[0].lower()
+    if command in {"/connect", "connect"}:
+        return parts[1]
+    return None
+
+
 def is_known_channel_command(text: str) -> bool:
     """Return whether text starts with a registered channel control command."""
     if not text.startswith("/"):
diff --git a/backend/app/channels/connection_identity.py b/backend/app/channels/connection_identity.py
new file mode 100644
index 000000000..162498aff
--- /dev/null
+++ b/backend/app/channels/connection_identity.py
@@ -0,0 +1,49 @@
+"""Helpers for attaching persisted channel connection ownership to inbound messages."""
+
+from __future__ import annotations
+
+from typing import Any
+
+from app.channels.message_bus import InboundMessage
+
+
+async def attach_connection_identity(
+    inbound: InboundMessage,
+    *,
+    repo: Any,
+    provider: str,
+    workspace_id: str | None,
+    fallback_without_workspace: bool = False,
+) -> InboundMessage:
+    """Attach connection metadata to an inbound message when a persisted binding exists.
+
+    The binding for ``workspace_id`` is looked up first; when ``fallback_without_workspace``
+    is true, a binding stored without a workspace is tried next. ``inbound`` is updated in
+    place and returned; it is returned untouched when ``repo`` is None or nothing matches.
+    """
+    if repo is None:
+        return inbound
+
+    workspace_candidates: list[str | None] = []
+    if workspace_id:
+        workspace_candidates.append(workspace_id)
+    if fallback_without_workspace:
+        workspace_candidates.append(None)
+    if not workspace_candidates:
+        return inbound
+
+    for candidate in workspace_candidates:
+        connection = await repo.find_connection_by_external_identity(
+            provider=provider,
+            external_account_id=inbound.user_id,
+            workspace_id=candidate,
+        )
+        if connection is None:
+            continue
+
+        inbound.connection_id = connection["id"]
+        inbound.owner_user_id = connection["owner_user_id"]
+        inbound.workspace_id = connection.get("workspace_id")
+        return inbound
+
+    return inbound
diff --git a/backend/app/channels/discord.py b/backend/app/channels/discord.py
index d997de483..d81a71fd6 100644
--- a/backend/app/channels/discord.py
+++ b/backend/app/channels/discord.py
@@ -10,7 +10,8 @@ from pathlib import Path
 from typing import Any
 
 from app.channels.base import Channel
-from app.channels.commands import is_known_channel_command
+from app.channels.commands import extract_connect_code, is_known_channel_command
+from app.channels.connection_identity import attach_connection_identity
 from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
 
 logger = logging.getLogger(__name__)
@@ -18,16 +19,6 @@ logger = logging.getLogger(__name__)
 _DISCORD_MAX_MESSAGE_LEN = 2000
 
 
-def _extract_connect_code(text: str) -> str | None:
-    parts = text.strip().split()
-    if len(parts) < 2:
-        return None
-    command = parts[0].lower()
-    if command in {"/connect", "connect"}:
-        return parts[1]
-    return None
-
-
 class DiscordChannel(Channel):
     """Discord bot channel.
 
@@ -298,7 +289,7 @@ class DiscordChannel(Channel):
             text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
             # Don't return early if text is empty — still process the mention (e.g., create thread)
 
-        connect_code = _extract_connect_code(text)
+        connect_code = extract_connect_code(text)
         if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
             return
 
@@ -454,29 +445,13 @@ class DiscordChannel(Channel):
             future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
 
     async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
-        if self._connection_repo is None:
-            return inbound
-
-        connection = None
-        if guild_id:
-            connection = await self._connection_repo.find_connection_by_external_identity(
-                provider="discord",
-                external_account_id=inbound.user_id,
-                workspace_id=guild_id,
-            )
-        if connection is None:
-            connection = await self._connection_repo.find_connection_by_external_identity(
-                provider="discord",
-                external_account_id=inbound.user_id,
-                workspace_id=None,
-            )
-        if connection is None:
-            return inbound
-
-        inbound.connection_id = connection["id"]
-        inbound.owner_user_id = connection["owner_user_id"]
-        inbound.workspace_id = connection.get("workspace_id")
-        return inbound
+        return await attach_connection_identity(
+            inbound,
+            repo=self._connection_repo,
+            provider="discord",
+            workspace_id=guild_id,
+            fallback_without_workspace=True,
+        )
 
     async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
         if self._connection_repo is None or not code:
diff --git a/backend/app/channels/slack.py b/backend/app/channels/slack.py
index 9f2171be3..2ca713bcb 100644
--- a/backend/app/channels/slack.py
+++ b/backend/app/channels/slack.py
@@ -9,7 +9,8 @@ from typing import Any
 from markdown_to_mrkdwn import SlackMarkdownConverter
 
 from app.channels.base import Channel
-from app.channels.commands import is_known_channel_command
+from app.channels.commands import extract_connect_code, is_known_channel_command
+from app.channels.connection_identity import attach_connection_identity
 from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
 
 logger = logging.getLogger(__name__)
@@ -47,16 +48,6 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
     return text[end + 1 :].lstrip()
 
 
-def _extract_connect_code(text: str) -> str | None:
-    parts = text.strip().split()
-    if len(parts) < 2:
-        return None
-    command = parts[0].lower()
-    if command in {"/connect", "connect"}:
-        return parts[1]
-    return None
-
-
 class SlackChannel(Channel):
     """Slack IM channel using Socket Mode (WebSocket, no public IP).
 
@@ -325,7 +316,7 @@ class SlackChannel(Channel):
         if not text:
             return
 
-        connect_code = _extract_connect_code(text)
+        connect_code = extract_connect_code(text)
         if connect_code:
             if self._loop and self._loop.is_running():
                 asyncio.run_coroutine_threadsafe(
@@ -373,25 +364,13 @@ class SlackChannel(Channel):
         await self.bus.publish_inbound(inbound)
 
     async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
-        if self._connection_repo is None:
-            return inbound
-
         workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
-        if not workspace_id:
-            return inbound
-
-        connection = await self._connection_repo.find_connection_by_external_identity(
+        return await attach_connection_identity(
+            inbound,
+            repo=self._connection_repo,
             provider="slack",
-            external_account_id=inbound.user_id,
             workspace_id=workspace_id,
         )
-        if connection is None:
-            return inbound
-
-        inbound.connection_id = connection["id"]
-        inbound.owner_user_id = connection["owner_user_id"]
-        inbound.workspace_id = connection.get("workspace_id")
-        return inbound
 
     async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
         if self._connection_repo is None or not code:
diff --git a/backend/app/channels/telegram.py b/backend/app/channels/telegram.py
index d443cd989..2e85afc75 100644
--- a/backend/app/channels/telegram.py
+++ b/backend/app/channels/telegram.py
@@ -8,6 +8,7 @@ import threading
 from typing import Any
 
 from app.channels.base import Channel
+from app.channels.connection_identity import attach_connection_identity
 from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
 
 logger = logging.getLogger(__name__)
@@ -295,21 +296,12 @@ class TelegramChannel(Channel):
         return True
 
     async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
-        if self._connection_repo is None:
-            return inbound
-
-        connection = await self._connection_repo.find_connection_by_external_identity(
+        return await attach_connection_identity(
+            inbound,
+            repo=self._connection_repo,
             provider="telegram",
-            external_account_id=inbound.user_id,
             workspace_id=inbound.chat_id,
         )
-        if connection is None:
-            return inbound
-
-        inbound.connection_id = connection["id"]
-        inbound.owner_user_id = connection["owner_user_id"]
-        inbound.workspace_id = connection.get("workspace_id")
-        return inbound
 
     def _get_bot_username(self, context) -> str | None:
         bot = getattr(context, "bot", None)
diff --git a/backend/app/gateway/routers/channel_connections.py b/backend/app/gateway/routers/channel_connections.py
index 21074a089..8981b1b86 100644
--- a/backend/app/gateway/routers/channel_connections.py
+++ b/backend/app/gateway/routers/channel_connections.py
@@ -203,6 +203,31 @@ def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) ->
     raise HTTPException(status_code=404, detail="Unknown channel provider")
 
 
+def _connection_updated_at(connection: dict[str, Any]) -> datetime:
+    """Return ``updated_at`` as an aware datetime; naive values are taken as UTC, unusable ones sort oldest."""
+    value = connection.get("updated_at")
+    if isinstance(value, str) and value:
+        try:
+            value = datetime.fromisoformat(value.replace("Z", "+00:00"))
+        except ValueError:
+            value = None
+    # Normalise parsed strings as well: an ISO string without an offset parses to a
+    # naive datetime, and comparing naive with aware values raises TypeError.
+    if isinstance(value, datetime):
+        return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
+    return datetime.min.replace(tzinfo=UTC)
+
+
+def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
+    """Map each provider to its connection with the latest ``updated_at`` (first seen wins ties)."""
+    by_provider: dict[str, dict[str, Any]] = {}
+    for item in connections:
+        existing = by_provider.get(item["provider"])
+        if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
+            by_provider[item["provider"]] = item
+    return by_provider
+
+
 @router.get("/providers", response_model=ChannelProvidersResponse)
 async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
     config = _get_channel_connections_config(request)
@@ -216,9 +241,7 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
         raise
     owner_user_id = _get_user_id(request)
     connections = await repo.list_connections(owner_user_id) if repo is not None else []
-    by_provider: dict[str, dict[str, Any]] = {}
-    for item in connections:
-        by_provider.setdefault(item["provider"], item)
+    by_provider = _newest_connection_by_provider(connections)
 
     providers: list[ChannelProviderResponse] = []
     for provider, meta in _PROVIDER_META.items():
diff --git a/backend/packages/harness/deerflow/persistence/channel_connections/sql.py b/backend/packages/harness/deerflow/persistence/channel_connections/sql.py
index 3cb40b46f..e810c359e 100644
--- a/backend/packages/harness/deerflow/persistence/channel_connections/sql.py
+++ b/backend/packages/harness/deerflow/persistence/channel_connections/sql.py
@@ -10,7 +10,7 @@ from datetime import UTC, datetime
 from typing import Any
 
 from cryptography.fernet import Fernet
-from sqlalchemy import select
+from sqlalchemy import delete, select
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 
 from deerflow.persistence.channel_connections.model import (
@@ -257,11 +257,14 @@ class ChannelConnectionRepository:
     ) -> dict[str, Any] | None:
         current_time = now or datetime.now(UTC)
         async with self.session_factory() as session:
+            await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
             row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
             if row is None or row.provider != provider or row.consumed_at is not None:
+                await session.commit()
                 return None
             expires_at = self._coerce_datetime(row.expires_at)
             if expires_at is not None and expires_at < current_time:
+                await session.commit()
                 return None
 
             row.consumed_at = current_time
diff --git a/backend/tests/test_channel_connections_repository.py b/backend/tests/test_channel_connections_repository.py
index 520bf0b78..94be35679 100644
--- a/backend/tests/test_channel_connections_repository.py
+++ b/backend/tests/test_channel_connections_repository.py
@@ -12,6 +12,7 @@ from deerflow.persistence.channel_connections import (
     ChannelConnectionRow,
     ChannelCredentialCipher,
     ChannelCredentialRow,
+    ChannelOAuthStateRow,
 )
 
 
@@ -200,3 +201,26 @@ class TestChannelConnectionRepository:
 
         assert disconnected is False
         assert (await repo.list_connections("alice"))[0]["status"] == "connected"
+
+    @pytest.mark.anyio
+    async def test_consume_oauth_state_deletes_expired_states(self, repo):
+        now = datetime.now(UTC)
+        await repo.create_oauth_state(
+            owner_user_id="alice",
+            provider="slack",
+            state="expired-state",
+            expires_at=now - timedelta(minutes=1),
+        )
+        await repo.create_oauth_state(
+            owner_user_id="alice",
+            provider="slack",
+            state="active-state",
+            expires_at=now + timedelta(minutes=5),
+        )
+
+        consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
+
+        assert consumed is None
+        async with repo.session_factory() as session:
+            states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
+        assert [state.state_hash for state in states] == [repo.hash_state("active-state")]