mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Address IM channel review comments
This commit is contained in:
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def extract_connect_code(text: str) -> str | None:
|
||||
"""Extract the one-time channel binding code from a connect command."""
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
def is_known_channel_command(text: str) -> bool:
|
||||
"""Return whether text starts with a registered channel control command."""
|
||||
if not text.startswith("/"):
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from app.channels.message_bus import InboundMessage
|
||||
|
||||
|
||||
async def attach_connection_identity(
|
||||
inbound: InboundMessage,
|
||||
*,
|
||||
repo: Any,
|
||||
provider: str,
|
||||
workspace_id: str | None,
|
||||
fallback_without_workspace: bool = False,
|
||||
) -> InboundMessage:
|
||||
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
||||
if repo is None:
|
||||
return inbound
|
||||
|
||||
workspace_candidates: list[str | None] = []
|
||||
if workspace_id:
|
||||
workspace_candidates.append(workspace_id)
|
||||
if fallback_without_workspace:
|
||||
workspace_candidates.append(None)
|
||||
if not workspace_candidates:
|
||||
return inbound
|
||||
|
||||
for candidate in workspace_candidates:
|
||||
connection = await repo.find_connection_by_external_identity(
|
||||
provider=provider,
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=candidate,
|
||||
)
|
||||
if connection is None:
|
||||
continue
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
return inbound
|
||||
@@ -10,7 +10,8 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,16 +19,6 @@ logger = logging.getLogger(__name__)
|
||||
_DISCORD_MAX_MESSAGE_LEN = 2000
|
||||
|
||||
|
||||
def _extract_connect_code(text: str) -> str | None:
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
class DiscordChannel(Channel):
|
||||
"""Discord bot channel.
|
||||
|
||||
@@ -298,7 +289,7 @@ class DiscordChannel(Channel):
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
connect_code = _extract_connect_code(text)
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||
return
|
||||
|
||||
@@ -454,29 +445,13 @@ class DiscordChannel(Channel):
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
connection = None
|
||||
if guild_id:
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="discord",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=guild_id,
|
||||
)
|
||||
if connection is None:
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="discord",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=None,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="discord",
|
||||
workspace_id=guild_id,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
|
||||
@@ -9,7 +9,8 @@ from typing import Any
|
||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -47,16 +48,6 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
|
||||
return text[end + 1 :].lstrip()
|
||||
|
||||
|
||||
def _extract_connect_code(text: str) -> str | None:
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
class SlackChannel(Channel):
|
||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||
|
||||
@@ -325,7 +316,7 @@ class SlackChannel(Channel):
|
||||
if not text:
|
||||
return
|
||||
|
||||
connect_code = _extract_connect_code(text)
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code:
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
@@ -373,25 +364,13 @@ class SlackChannel(Channel):
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||
if not workspace_id:
|
||||
return inbound
|
||||
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="slack",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
|
||||
@@ -8,6 +8,7 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -295,21 +296,12 @@ class TelegramChannel(Channel):
|
||||
return True
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="telegram",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
def _get_bot_username(self, context) -> str | None:
|
||||
bot = getattr(context, "bot", None)
|
||||
|
||||
@@ -203,6 +203,27 @@ def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) ->
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
|
||||
|
||||
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
||||
value = connection.get("updated_at")
|
||||
if isinstance(value, datetime):
|
||||
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
||||
if isinstance(value, str) and value:
|
||||
try:
|
||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
except ValueError:
|
||||
pass
|
||||
return datetime.min.replace(tzinfo=UTC)
|
||||
|
||||
|
||||
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
by_provider: dict[str, dict[str, Any]] = {}
|
||||
for item in connections:
|
||||
existing = by_provider.get(item["provider"])
|
||||
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
|
||||
by_provider[item["provider"]] = item
|
||||
return by_provider
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
@@ -216,9 +237,7 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
raise
|
||||
owner_user_id = _get_user_id(request)
|
||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||
by_provider: dict[str, dict[str, Any]] = {}
|
||||
for item in connections:
|
||||
by_provider.setdefault(item["provider"], item)
|
||||
by_provider = _newest_connection_by_provider(connections)
|
||||
|
||||
providers: list[ChannelProviderResponse] = []
|
||||
for provider, meta in _PROVIDER_META.items():
|
||||
|
||||
@@ -10,7 +10,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
@@ -257,11 +257,14 @@ class ChannelConnectionRepository:
|
||||
) -> dict[str, Any] | None:
|
||||
current_time = now or datetime.now(UTC)
|
||||
async with self.session_factory() as session:
|
||||
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
|
||||
row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
|
||||
if row is None or row.provider != provider or row.consumed_at is not None:
|
||||
await session.commit()
|
||||
return None
|
||||
expires_at = self._coerce_datetime(row.expires_at)
|
||||
if expires_at is not None and expires_at < current_time:
|
||||
await session.commit()
|
||||
return None
|
||||
|
||||
row.consumed_at = current_time
|
||||
|
||||
@@ -12,6 +12,7 @@ from deerflow.persistence.channel_connections import (
|
||||
ChannelConnectionRow,
|
||||
ChannelCredentialCipher,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
|
||||
|
||||
@@ -200,3 +201,26 @@ class TestChannelConnectionRepository:
|
||||
|
||||
assert disconnected is False
|
||||
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_consume_oauth_state_deletes_expired_states(self, repo):
|
||||
now = datetime.now(UTC)
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
state="expired-state",
|
||||
expires_at=now - timedelta(minutes=1),
|
||||
)
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
state="active-state",
|
||||
expires_at=now + timedelta(minutes=5),
|
||||
)
|
||||
|
||||
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
|
||||
|
||||
assert consumed is None
|
||||
async with repo.session_factory() as session:
|
||||
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
|
||||
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
|
||||
|
||||
Reference in New Issue
Block a user