Address IM channel review comments

This commit is contained in:
taohe
2026-06-11 10:33:44 +08:00
parent 87200ff920
commit b7097baaec
8 changed files with 125 additions and 78 deletions
+11
View File
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
) )
def extract_connect_code(text: str) -> str | None:
"""Extract the one-time channel binding code from a connect command."""
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
def is_known_channel_command(text: str) -> bool: def is_known_channel_command(text: str) -> bool:
"""Return whether text starts with a registered channel control command.""" """Return whether text starts with a registered channel control command."""
if not text.startswith("/"): if not text.startswith("/"):
@@ -0,0 +1,44 @@
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
from __future__ import annotations
from typing import Any
from app.channels.message_bus import InboundMessage
async def attach_connection_identity(
inbound: InboundMessage,
*,
repo: Any,
provider: str,
workspace_id: str | None,
fallback_without_workspace: bool = False,
) -> InboundMessage:
"""Attach connection metadata to an inbound message when a persisted binding exists."""
if repo is None:
return inbound
workspace_candidates: list[str | None] = []
if workspace_id:
workspace_candidates.append(workspace_id)
if fallback_without_workspace:
workspace_candidates.append(None)
if not workspace_candidates:
return inbound
for candidate in workspace_candidates:
connection = await repo.find_connection_by_external_identity(
provider=provider,
external_account_id=inbound.user_id,
workspace_id=candidate,
)
if connection is None:
continue
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
return inbound
+7 -32
View File
@@ -10,7 +10,8 @@ from pathlib import Path
from typing import Any from typing import Any
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,16 +19,6 @@ logger = logging.getLogger(__name__)
_DISCORD_MAX_MESSAGE_LEN = 2000 _DISCORD_MAX_MESSAGE_LEN = 2000
def _extract_connect_code(text: str) -> str | None:
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
class DiscordChannel(Channel): class DiscordChannel(Channel):
"""Discord bot channel. """Discord bot channel.
@@ -298,7 +289,7 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip() text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread) # Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = _extract_connect_code(text) connect_code = extract_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code): if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return return
@@ -454,29 +445,13 @@ class DiscordChannel(Channel):
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None) future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage: async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
if self._connection_repo is None: return await attach_connection_identity(
return inbound inbound,
repo=self._connection_repo,
connection = None
if guild_id:
connection = await self._connection_repo.find_connection_by_external_identity(
provider="discord", provider="discord",
external_account_id=inbound.user_id,
workspace_id=guild_id, workspace_id=guild_id,
fallback_without_workspace=True,
) )
if connection is None:
connection = await self._connection_repo.find_connection_by_external_identity(
provider="discord",
external_account_id=inbound.user_id,
workspace_id=None,
)
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
async def _bind_connection_from_connect_code(self, message, code: str) -> bool: async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
if self._connection_repo is None or not code: if self._connection_repo is None or not code:
+6 -27
View File
@@ -9,7 +9,8 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.commands import is_known_channel_command from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -47,16 +48,6 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
return text[end + 1 :].lstrip() return text[end + 1 :].lstrip()
def _extract_connect_code(text: str) -> str | None:
parts = text.strip().split()
if len(parts) < 2:
return None
command = parts[0].lower()
if command in {"/connect", "connect"}:
return parts[1]
return None
class SlackChannel(Channel): class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP). """Slack IM channel using Socket Mode (WebSocket, no public IP).
@@ -325,7 +316,7 @@ class SlackChannel(Channel):
if not text: if not text:
return return
connect_code = _extract_connect_code(text) connect_code = extract_connect_code(text)
if connect_code: if connect_code:
if self._loop and self._loop.is_running(): if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
@@ -373,25 +364,13 @@ class SlackChannel(Channel):
await self.bus.publish_inbound(inbound) await self.bus.publish_inbound(inbound)
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None): async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
if self._connection_repo is None:
return inbound
workspace_id = str(team_id or inbound.metadata.get("team_id") or "") workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
if not workspace_id: return await attach_connection_identity(
return inbound inbound,
repo=self._connection_repo,
connection = await self._connection_repo.find_connection_by_external_identity(
provider="slack", provider="slack",
external_account_id=inbound.user_id,
workspace_id=workspace_id, workspace_id=workspace_id,
) )
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool: async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
if self._connection_repo is None or not code: if self._connection_repo is None or not code:
+4 -12
View File
@@ -8,6 +8,7 @@ import threading
from typing import Any from typing import Any
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -295,21 +296,12 @@ class TelegramChannel(Channel):
return True return True
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage: async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
if self._connection_repo is None: return await attach_connection_identity(
return inbound inbound,
repo=self._connection_repo,
connection = await self._connection_repo.find_connection_by_external_identity(
provider="telegram", provider="telegram",
external_account_id=inbound.user_id,
workspace_id=inbound.chat_id, workspace_id=inbound.chat_id,
) )
if connection is None:
return inbound
inbound.connection_id = connection["id"]
inbound.owner_user_id = connection["owner_user_id"]
inbound.workspace_id = connection.get("workspace_id")
return inbound
def _get_bot_username(self, context) -> str | None: def _get_bot_username(self, context) -> str | None:
bot = getattr(context, "bot", None) bot = getattr(context, "bot", None)
@@ -203,6 +203,27 @@ def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) ->
raise HTTPException(status_code=404, detail="Unknown channel provider") raise HTTPException(status_code=404, detail="Unknown channel provider")
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
value = connection.get("updated_at")
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
if isinstance(value, str) and value:
try:
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
pass
return datetime.min.replace(tzinfo=UTC)
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
by_provider: dict[str, dict[str, Any]] = {}
for item in connections:
existing = by_provider.get(item["provider"])
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
by_provider[item["provider"]] = item
return by_provider
@router.get("/providers", response_model=ChannelProvidersResponse) @router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse: async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
config = _get_channel_connections_config(request) config = _get_channel_connections_config(request)
@@ -216,9 +237,7 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
raise raise
owner_user_id = _get_user_id(request) owner_user_id = _get_user_id(request)
connections = await repo.list_connections(owner_user_id) if repo is not None else [] connections = await repo.list_connections(owner_user_id) if repo is not None else []
by_provider: dict[str, dict[str, Any]] = {} by_provider = _newest_connection_by_provider(connections)
for item in connections:
by_provider.setdefault(item["provider"], item)
providers: list[ChannelProviderResponse] = [] providers: list[ChannelProviderResponse] = []
for provider, meta in _PROVIDER_META.items(): for provider, meta in _PROVIDER_META.items():
@@ -10,7 +10,7 @@ from datetime import UTC, datetime
from typing import Any from typing import Any
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from sqlalchemy import select from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.channel_connections.model import ( from deerflow.persistence.channel_connections.model import (
@@ -257,11 +257,14 @@ class ChannelConnectionRepository:
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
current_time = now or datetime.now(UTC) current_time = now or datetime.now(UTC)
async with self.session_factory() as session: async with self.session_factory() as session:
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
row = await session.get(ChannelOAuthStateRow, self.hash_state(state)) row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
if row is None or row.provider != provider or row.consumed_at is not None: if row is None or row.provider != provider or row.consumed_at is not None:
await session.commit()
return None return None
expires_at = self._coerce_datetime(row.expires_at) expires_at = self._coerce_datetime(row.expires_at)
if expires_at is not None and expires_at < current_time: if expires_at is not None and expires_at < current_time:
await session.commit()
return None return None
row.consumed_at = current_time row.consumed_at = current_time
@@ -12,6 +12,7 @@ from deerflow.persistence.channel_connections import (
ChannelConnectionRow, ChannelConnectionRow,
ChannelCredentialCipher, ChannelCredentialCipher,
ChannelCredentialRow, ChannelCredentialRow,
ChannelOAuthStateRow,
) )
@@ -200,3 +201,26 @@ class TestChannelConnectionRepository:
assert disconnected is False assert disconnected is False
assert (await repo.list_connections("alice"))[0]["status"] == "connected" assert (await repo.list_connections("alice"))[0]["status"] == "connected"
@pytest.mark.anyio
async def test_consume_oauth_state_deletes_expired_states(self, repo):
now = datetime.now(UTC)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="expired-state",
expires_at=now - timedelta(minutes=1),
)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="active-state",
expires_at=now + timedelta(minutes=5),
)
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
assert consumed is None
async with repo.session_factory() as session:
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]