Address IM channel review comments

This commit is contained in:
taohe
2026-06-11 10:33:44 +08:00
parent 87200ff920
commit b7097baaec
8 changed files with 125 additions and 78 deletions
+11
View File
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
)
def extract_connect_code(text: str) -> str | None:
"""Extract the one-time channel binding code from a connect command."""
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command."""
if not text.startswith("/"):
@@ -0,0 +1,44 @@
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
from __future__ import annotations
from typing import Any
from app.channels.message_bus import InboundMessage
async def attach_connection_identity(
inbound: InboundMessage,
*,
repo: Any,
provider: str,
workspace_id: str | None,
fallback_without_workspace: bool = False,
) -> InboundMessage:
"""Attach connection metadata to an inbound message when a persisted binding exists."""
if repo is None:
return inbound
workspace_candidates: list[str | None] = []
if workspace_id:
workspace_candidates.append(workspace_id)
if fallback_without_workspace:
workspace_candidates.append(None)
if not workspace_candidates:
return inbound
for candidate in workspace_candidates:
connection = await repo.find_connection_by_external_identity(
provider=provider,
external_account_id=inbound.user_id,
workspace_id=candidate,
)
if connection is None:
continue
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return inbound
+10 -35
View File
@@ -10,7 +10,8 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -18,16 +19,6 @@ logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000
def _extract_connect_code(text: str) -> str | None:
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
class DiscordChannel(Channel):
"""Discord bot channel.
@@ -298,7 +289,7 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = _extract_connect_code(text)
connect_code = extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
@@ -454,29 +445,13 @@ class DiscordChannel(Channel):
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
if self._connection_repo is None:
return inbound
connection = None
if guild_id:
connection = await self._connection_repo.find_connection_by_external_identity(
provider="discord",
external_account_id=inbound.user_id,
workspace_id=guild_id,
)
if connection is None:
connection = await self._connection_repo.find_connection_by_external_identity(
provider="discord",
external_account_id=inbound.user_id,
workspace_id=None,
)
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="discord",
workspace_id=guild_id,
fallback_without_workspace=True,
)
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code:
+6 -27
View File
@@ -9,7 +9,8 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import is_known_channel_command
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -47,16 +48,6 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
return text[end + 1 :].lstrip()
def _extract_connect_code(text: str) -> str | None:
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
@@ -325,7 +316,7 @@ class SlackChannel(Channel):
if not text:
return
connect_code = _extract_connect_code(text)
connect_code = extract_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
@@ -373,25 +364,13 @@ class SlackChannel(Channel):
await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
if self._connection_repo is None:
return inbound
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
if not workspace_id:
return inbound
connection = await self._connection_repo.find_connection_by_external_identity(
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="slack",
external_account_id=inbound.user_id,
workspace_id=workspace_id,
)
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code:
+4 -12
View File
@@ -8,6 +8,7 @@ import threading
from typing import Any
from app.channels.base import Channel
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -295,21 +296,12 @@ class TelegramChannel(Channel):
return True
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
if self._connection_repo is None:
return inbound
connection = await self._connection_repo.find_connection_by_external_identity(
return await attach_connection_identity(
inbound,
repo=self._connection_repo,
provider="telegram",
external_account_id=inbound.user_id,
workspace_id=inbound.chat_id,
)
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None)
@@ -203,6 +203,27 @@ def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) ->
raise HTTPException(status_code=404, detail="Unknown channel provider")
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
value = connection.get("updated_at")
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
if isinstance(value, str) and value:
try:
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
pass
return datetime.min.replace(tzinfo=UTC)
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
by_provider: dict[str, dict[str, Any]] = {}
for item in connections:
existing = by_provider.get(item["provider"])
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
by_provider[item["provider"]] = item
return by_provider
@router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
config = _get_channel_connections_config(request)
@@ -216,9 +237,7 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
raise
owner_user_id = _get_user_id(request)
connections = await repo.list_connections(owner_user_id) if repo is not None else []
by_provider: dict[str, dict[str, Any]] = {}
for item in connections:
by_provider.setdefault(item["provider"], item)
by_provider = _newest_connection_by_provider(connections)
providers: list[ChannelProviderResponse] = []
for provider, meta in _PROVIDER_META.items():