mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Address IM channel review comments
This commit is contained in:
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_connect_code(text: str) -> str | None:
|
||||||
|
"""Extract the one-time channel binding code from a connect command."""
|
||||||
|
parts = text.strip().split()
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
command = parts[0].lower()
|
||||||
|
if command in {"/connect", "connect"}:
|
||||||
|
return parts[1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def is_known_channel_command(text: str) -> bool:
|
def is_known_channel_command(text: str) -> bool:
|
||||||
"""Return whether text starts with a registered channel control command."""
|
"""Return whether text starts with a registered channel control command."""
|
||||||
if not text.startswith("/"):
|
if not text.startswith("/"):
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.channels.message_bus import InboundMessage
|
||||||
|
|
||||||
|
|
||||||
|
async def attach_connection_identity(
|
||||||
|
inbound: InboundMessage,
|
||||||
|
*,
|
||||||
|
repo: Any,
|
||||||
|
provider: str,
|
||||||
|
workspace_id: str | None,
|
||||||
|
fallback_without_workspace: bool = False,
|
||||||
|
) -> InboundMessage:
|
||||||
|
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
||||||
|
if repo is None:
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
workspace_candidates: list[str | None] = []
|
||||||
|
if workspace_id:
|
||||||
|
workspace_candidates.append(workspace_id)
|
||||||
|
if fallback_without_workspace:
|
||||||
|
workspace_candidates.append(None)
|
||||||
|
if not workspace_candidates:
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
for candidate in workspace_candidates:
|
||||||
|
connection = await repo.find_connection_by_external_identity(
|
||||||
|
provider=provider,
|
||||||
|
external_account_id=inbound.user_id,
|
||||||
|
workspace_id=candidate,
|
||||||
|
)
|
||||||
|
if connection is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
inbound.connection_id = connection["id"]
|
||||||
|
inbound.owner_user_id = connection["owner_user_id"]
|
||||||
|
inbound.workspace_id = connection.get("workspace_id")
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
return inbound
|
||||||
@@ -10,7 +10,8 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -18,16 +19,6 @@ logger = logging.getLogger(__name__)
|
|||||||
_DISCORD_MAX_MESSAGE_LEN = 2000
|
_DISCORD_MAX_MESSAGE_LEN = 2000
|
||||||
|
|
||||||
|
|
||||||
def _extract_connect_code(text: str) -> str | None:
|
|
||||||
parts = text.strip().split()
|
|
||||||
if len(parts) < 2:
|
|
||||||
return None
|
|
||||||
command = parts[0].lower()
|
|
||||||
if command in {"/connect", "connect"}:
|
|
||||||
return parts[1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordChannel(Channel):
|
class DiscordChannel(Channel):
|
||||||
"""Discord bot channel.
|
"""Discord bot channel.
|
||||||
|
|
||||||
@@ -298,7 +289,7 @@ class DiscordChannel(Channel):
|
|||||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||||
|
|
||||||
connect_code = _extract_connect_code(text)
|
connect_code = extract_connect_code(text)
|
||||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -454,29 +445,13 @@ class DiscordChannel(Channel):
|
|||||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||||
if self._connection_repo is None:
|
return await attach_connection_identity(
|
||||||
return inbound
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
connection = None
|
|
||||||
if guild_id:
|
|
||||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
|
||||||
provider="discord",
|
provider="discord",
|
||||||
external_account_id=inbound.user_id,
|
|
||||||
workspace_id=guild_id,
|
workspace_id=guild_id,
|
||||||
|
fallback_without_workspace=True,
|
||||||
)
|
)
|
||||||
if connection is None:
|
|
||||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
|
||||||
provider="discord",
|
|
||||||
external_account_id=inbound.user_id,
|
|
||||||
workspace_id=None,
|
|
||||||
)
|
|
||||||
if connection is None:
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
inbound.connection_id = connection["id"]
|
|
||||||
inbound.owner_user_id = connection["owner_user_id"]
|
|
||||||
inbound.workspace_id = connection.get("workspace_id")
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||||
if self._connection_repo is None or not code:
|
if self._connection_repo is None or not code:
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from typing import Any
|
|||||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -47,16 +48,6 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
|
|||||||
return text[end + 1 :].lstrip()
|
return text[end + 1 :].lstrip()
|
||||||
|
|
||||||
|
|
||||||
def _extract_connect_code(text: str) -> str | None:
|
|
||||||
parts = text.strip().split()
|
|
||||||
if len(parts) < 2:
|
|
||||||
return None
|
|
||||||
command = parts[0].lower()
|
|
||||||
if command in {"/connect", "connect"}:
|
|
||||||
return parts[1]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class SlackChannel(Channel):
|
class SlackChannel(Channel):
|
||||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||||
|
|
||||||
@@ -325,7 +316,7 @@ class SlackChannel(Channel):
|
|||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
connect_code = _extract_connect_code(text)
|
connect_code = extract_connect_code(text)
|
||||||
if connect_code:
|
if connect_code:
|
||||||
if self._loop and self._loop.is_running():
|
if self._loop and self._loop.is_running():
|
||||||
asyncio.run_coroutine_threadsafe(
|
asyncio.run_coroutine_threadsafe(
|
||||||
@@ -373,25 +364,13 @@ class SlackChannel(Channel):
|
|||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||||
if self._connection_repo is None:
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||||
if not workspace_id:
|
return await attach_connection_identity(
|
||||||
return inbound
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
|
||||||
provider="slack",
|
provider="slack",
|
||||||
external_account_id=inbound.user_id,
|
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
)
|
)
|
||||||
if connection is None:
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
inbound.connection_id = connection["id"]
|
|
||||||
inbound.owner_user_id = connection["owner_user_id"]
|
|
||||||
inbound.workspace_id = connection.get("workspace_id")
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||||
if self._connection_repo is None or not code:
|
if self._connection_repo is None or not code:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import threading
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -295,21 +296,12 @@ class TelegramChannel(Channel):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
if self._connection_repo is None:
|
return await attach_connection_identity(
|
||||||
return inbound
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
|
||||||
provider="telegram",
|
provider="telegram",
|
||||||
external_account_id=inbound.user_id,
|
|
||||||
workspace_id=inbound.chat_id,
|
workspace_id=inbound.chat_id,
|
||||||
)
|
)
|
||||||
if connection is None:
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
inbound.connection_id = connection["id"]
|
|
||||||
inbound.owner_user_id = connection["owner_user_id"]
|
|
||||||
inbound.workspace_id = connection.get("workspace_id")
|
|
||||||
return inbound
|
|
||||||
|
|
||||||
def _get_bot_username(self, context) -> str | None:
|
def _get_bot_username(self, context) -> str | None:
|
||||||
bot = getattr(context, "bot", None)
|
bot = getattr(context, "bot", None)
|
||||||
|
|||||||
@@ -203,6 +203,27 @@ def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) ->
|
|||||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
|
||||||
|
|
||||||
|
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
||||||
|
value = connection.get("updated_at")
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return datetime.min.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
|
||||||
|
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||||
|
by_provider: dict[str, dict[str, Any]] = {}
|
||||||
|
for item in connections:
|
||||||
|
existing = by_provider.get(item["provider"])
|
||||||
|
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
|
||||||
|
by_provider[item["provider"]] = item
|
||||||
|
return by_provider
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||||
config = _get_channel_connections_config(request)
|
config = _get_channel_connections_config(request)
|
||||||
@@ -216,9 +237,7 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
|||||||
raise
|
raise
|
||||||
owner_user_id = _get_user_id(request)
|
owner_user_id = _get_user_id(request)
|
||||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||||
by_provider: dict[str, dict[str, Any]] = {}
|
by_provider = _newest_connection_by_provider(connections)
|
||||||
for item in connections:
|
|
||||||
by_provider.setdefault(item["provider"], item)
|
|
||||||
|
|
||||||
providers: list[ChannelProviderResponse] = []
|
providers: list[ChannelProviderResponse] = []
|
||||||
for provider, meta in _PROVIDER_META.items():
|
for provider, meta in _PROVIDER_META.items():
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from datetime import UTC, datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from sqlalchemy import select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from deerflow.persistence.channel_connections.model import (
|
from deerflow.persistence.channel_connections.model import (
|
||||||
@@ -257,11 +257,14 @@ class ChannelConnectionRepository:
|
|||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
current_time = now or datetime.now(UTC)
|
current_time = now or datetime.now(UTC)
|
||||||
async with self.session_factory() as session:
|
async with self.session_factory() as session:
|
||||||
|
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
|
||||||
row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
|
row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
|
||||||
if row is None or row.provider != provider or row.consumed_at is not None:
|
if row is None or row.provider != provider or row.consumed_at is not None:
|
||||||
|
await session.commit()
|
||||||
return None
|
return None
|
||||||
expires_at = self._coerce_datetime(row.expires_at)
|
expires_at = self._coerce_datetime(row.expires_at)
|
||||||
if expires_at is not None and expires_at < current_time:
|
if expires_at is not None and expires_at < current_time:
|
||||||
|
await session.commit()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
row.consumed_at = current_time
|
row.consumed_at = current_time
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from deerflow.persistence.channel_connections import (
|
|||||||
ChannelConnectionRow,
|
ChannelConnectionRow,
|
||||||
ChannelCredentialCipher,
|
ChannelCredentialCipher,
|
||||||
ChannelCredentialRow,
|
ChannelCredentialRow,
|
||||||
|
ChannelOAuthStateRow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -200,3 +201,26 @@ class TestChannelConnectionRepository:
|
|||||||
|
|
||||||
assert disconnected is False
|
assert disconnected is False
|
||||||
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_consume_oauth_state_deletes_expired_states(self, repo):
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
state="expired-state",
|
||||||
|
expires_at=now - timedelta(minutes=1),
|
||||||
|
)
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
state="active-state",
|
||||||
|
expires_at=now + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
|
||||||
|
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
|
||||||
|
|
||||||
|
assert consumed is None
|
||||||
|
async with repo.session_factory() as session:
|
||||||
|
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
|
||||||
|
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
|
||||||
|
|||||||
Reference in New Issue
Block a user