mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 05:25:57 +00:00
fix(channels): make channel connect flow deterministic (#3582)
* fix(channels): make channel connect flow deterministic

* make format

* fix(channels): apply connect-code before allowed_users on telegram and wechat

The bind-bootstrap reorder shipped for slack/dingtalk only. Telegram and WeChat still gate _check_user/allowed_users before connect-code dispatch, so a newly allowlisted-but-unbound user is silently rejected when binding via the browser deep-link / connect-code flow: the same deadlock the PR fixes.

- telegram: consume the /start deep-link token before the allowed_users gate.
- wechat: handle the /connect code before the allowed_users gate, and defer inbound file extraction + context-token tracking past the gate so blocked senders no longer trigger CDN downloads or token bookkeeping.

Adds regression tests for both adapters mirroring the slack/dingtalk coverage.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): enforce single-active-owner invariant at the DB layer

_revoke_other_active_owners did a SELECT-then-UPDATE in app code with no row lock or constraint covering active rows. Under READ COMMITTED, two concurrent connect-code consumes for the same (provider, external_account_id, workspace_id) from different owners could each observe "no other active owner" and both commit a connected row, leaving find_connection_by_external_identity nondeterministic.

- Add a partial unique index on (provider, external_account_id, workspace_id) WHERE status != 'revoked' (portable to SQLite >= 3.8.0 and PostgreSQL) so the database guarantees at most one non-revoked row per external identity.
- Reorder upsert_connection to revoke other owners' active rows before the new connected row is flushed (so the index is satisfied at commit), wrapped in a bounded rollback-and-retry loop. A losing concurrent writer now retries against the now-visible state instead of committing a duplicate.

Adds DB-constraint, revoked-slot-reuse, and concurrent-upsert regression tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): harden connect-status polling primitive

pollChannelConnectionUntilResolved was a free-floating recursive setTimeout started from onSuccess with no cancellation, no per-provider dedup, a redundant second endpoint per tick, and an unbounded loop on a non-finite expires_in.

- Extract a framework-agnostic, cancellable poller (connect-poll.ts) that polls only listChannelConnections() and invalidates the providers query once when the bind resolves, instead of fetching both endpoints every tick.
- Guard expires_in with a finite check + default window so undefined/NaN can no longer produce a poll loop that runs until the page closes.
- Track one active poll handle per provider in useConnectChannelProvider via a ref Map: a new connect cancels the prior poll for that provider, and a useEffect cleanup cancels all polls on unmount.

Adds unit tests for resolve-and-stop, cancellation, and non-finite-expiry.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): stop leaking blocked-sender content in DingTalk INFO log; document bind semantics

Moving the allowed_users gate past _extract_text meant the parsed-message INFO log (text=%r, first 100 chars) fired for senders that allowed_users would have rejected, defeating the filter's noise/privacy role. Move that log to after the allowed_users gate so blocked senders' message text never reaches INFO logs.

Also document the two operator-relevant semantic changes in backend/CLAUDE.md: connect-code dispatch runs before allowed_users (so allowed_users is no longer a bind-time defense; the model relies on code confidentiality + 600s TTL + one-time consumption), and the single-active-owner-per-external-identity transfer semantics now backed by the partial unique index.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* docs(channels): note connect-code-vs-allowlist and ownership transfer in operator guide

Mirror the backend/CLAUDE.md notes in the operator-facing IM_CHANNEL_CONNECTIONS.md: connect codes are consumed before allowed_users (so a not-yet-allowlisted user can still complete a first bind, and allowed_users is not a bind-time defense), and an external identity has at most one active owner with last-bind-wins transfer enforced at the DB layer.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* refactor(channels): lift connect-code dispatch into Channel base class

Each adapter duplicated the ordering-sensitive boilerplate of extracting a /connect code and guarding on the connection repo before its allowed_users gate. The duplication is what let telegram/wechat drift and keep the gate ahead of the bind. Centralize it:

- Move `_connection_repo` onto Channel.__init__ (removing 7 duplicate assignments).
- Add Channel._pending_connect_code(text), which guards on the repo and extracts the code, documenting that adapters MUST consult it before authorization so a browser-initiated bind can bootstrap a not-yet-authorized identity.
- Route slack, discord, feishu, dingtalk, wechat, and wecom through the helper. This also fixes a latent inconsistency where slack dispatched a bind even when no connection repo was configured.

Pure refactor: the full channel suite stays green; adds a direct unit test for the base helper's contract.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* make format

* fix(channels): redact DingTalk parsed-message INFO log content

Log text_len instead of the first 100 chars of message text, so message content never reaches INFO logs (the after-gate move already keeps blocked senders out entirely). This takes over the redaction from #3584 so only this PR touches dingtalk.py, letting the two PRs merge in any order conflict-free.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
@@ -420,6 +420,8 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user.
- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable.
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
- **Connect-code ordering vs `allowed_users`**: inbound workers consume a valid `/connect <code>` (or Telegram `/start <code>`) **before** applying the `allowed_users` filter, so a newly allowlisted-but-unbound user can bootstrap their first bind via the browser flow. Consequence: `allowed_users` is **not** a bind-time defense — any sender who possesses a valid code can consume it (not only allowlisted users). The bind security model rests on the code's confidentiality: `secrets.token_urlsafe(16)`, 600 s TTL, one-time `consume_oauth_state`, and codes surfaced only in the initiating browser (never echoed to chat). `allowed_users` still gates ordinary (non-bind) messages.
- **Single-active-owner transfer semantics**: an external identity is keyed by `(provider, external_account_id, workspace_id)`. The latest successful bind wins — `upsert_connection` revokes other owners' active rows for the same identity (ownership transfer). This invariant is enforced at the DB layer by the partial unique index `uq_channel_connection_active_identity` (`WHERE status != 'revoked'`), so concurrent connects from different owners cannot both end `connected`; the losing writer retries against the now-visible state. `find_connection_by_external_identity` therefore resolves deterministically.
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
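
The connect-code ordering described in the bullets above can be sketched roughly as follows. This is a minimal illustration, not the adapters' actual code: `route_inbound` and the regex are hypothetical, and the local `extract_connect_code` is only a stand-in for the project's helper of the same name, whose real parsing rules are not shown in this diff. The real adapters also fall through to normal handling when a bind attempt is not handled, which this sketch omits.

```python
import re
from typing import Optional, Set

# Hypothetical stand-in for app.channels.commands.extract_connect_code;
# the real helper's parsing rules are an assumption here.
_CONNECT_RE = re.compile(r"^/connect\s+(\S+)\s*$")


def extract_connect_code(text: str) -> Optional[str]:
    match = _CONNECT_RE.match(text.strip())
    return match.group(1) if match else None


def route_inbound(text: str, sender: str, allowed_users: Set[str], repo_configured: bool = True) -> str:
    """Classify one inbound message as "bind", "dropped", or "chat"."""
    # 1. Connect codes are looked at first, before any allowlist check,
    #    and only when a connection repo is configured.
    code = extract_connect_code(text) if repo_configured else None
    if code:
        return "bind"
    # 2. Ordinary (non-bind) messages are still gated by allowed_users.
    if allowed_users and sender not in allowed_users:
        return "dropped"
    return "chat"
```

With an allowlist of `{"U-old"}`, a `/connect` message from an unlisted sender is routed to the bind path, while any other message from that sender is dropped.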
@@ -9,6 +9,7 @@ from collections.abc import Awaitable, Callable
|
||||
from concurrent.futures import CancelledError as FutureCancelledError
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from app.channels.commands import extract_connect_code
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,6 +32,7 @@ class Channel(ABC):
|
||||
self.bus = bus
|
||||
self.config = config
|
||||
self._running = False
|
||||
self._connection_repo: Any = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@@ -117,6 +119,19 @@ class Channel(ABC):
|
||||
if exc:
|
||||
logger.error("[%s] %s failed for msg_id=%s: %s", self.name, name, msg_id, exc)
|
||||
|
||||
def _pending_connect_code(self, text: str) -> str | None:
|
||||
"""Return the one-time bind code if *text* is a ``/connect <code>`` command
|
||||
and channel connections are configured, else ``None``.
|
||||
|
||||
Adapters MUST consult this **before** applying their ``allowed_users`` /
|
||||
``_check_user`` gate, so a browser-initiated bind can bootstrap an external
|
||||
identity that the platform bot has never seen and is therefore not yet
|
||||
authorized. (Telegram uses its deep-link ``/start <token>`` flow instead.)
|
||||
"""
|
||||
if self._connection_repo is None:
|
||||
return None
|
||||
return extract_connect_code(text)
|
||||
|
||||
def _make_inbound(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
||||
@@ -14,7 +14,7 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
@@ -137,7 +137,6 @@ class DingTalkChannel(Channel):
|
||||
self._incoming_messages: dict[str, Any] = {}
|
||||
self._incoming_messages_lock = threading.Lock()
|
||||
self._card_repliers: dict[str, Any] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -366,26 +365,13 @@ class DingTalkChannel(Channel):
|
||||
msg_id = message.message_id or ""
|
||||
sender_nick = message.sender_nick or ""
|
||||
|
||||
if self._allowed_users and sender_staff_id not in self._allowed_users:
|
||||
logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
|
||||
return
|
||||
|
||||
text = self._extract_text(message)
|
||||
if not text:
|
||||
logger.info("[DingTalk] empty text, ignoring message")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text=%r",
|
||||
conversation_type,
|
||||
msg_id,
|
||||
sender_staff_id,
|
||||
sender_nick,
|
||||
text[:100],
|
||||
)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
connect_code = self._pending_connect_code(text)
|
||||
if connect_code:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
@@ -402,6 +388,22 @@ class DingTalkChannel(Channel):
|
||||
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
if self._allowed_users and sender_staff_id not in self._allowed_users:
|
||||
logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
|
||||
return
|
||||
|
||||
# Log only metadata (length, not content) so message text never reaches
|
||||
# INFO logs, and only after the allowed_users gate so blocked senders are
|
||||
# not logged at all.
|
||||
logger.info(
|
||||
"[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text_len=%d",
|
||||
conversation_type,
|
||||
msg_id,
|
||||
sender_staff_id,
|
||||
sender_nick,
|
||||
len(text or ""),
|
||||
)
|
||||
|
||||
if _is_dingtalk_command(text):
|
||||
msg_type = InboundMessageType.COMMAND
|
||||
else:
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
@@ -71,7 +71,6 @@ class DiscordChannel(Channel):
|
||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._discord_module = None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -293,7 +292,7 @@ class DiscordChannel(Channel):
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
connect_code = self._pending_connect_code(text)
|
||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||
return
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
PENDING_CLARIFICATION_METADATA_KEY,
|
||||
@@ -72,7 +72,6 @@ class FeishuChannel(Channel):
|
||||
self._CreateImageRequestBody = None
|
||||
self._GetMessageResourceRequest = None
|
||||
self._thread_lock = threading.Lock()
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@staticmethod
|
||||
def _non_empty_str(value: Any) -> str | None:
|
||||
@@ -851,8 +850,8 @@ class FeishuChannel(Channel):
|
||||
logger.info("[Feishu] empty text, ignoring message")
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
connect_code = self._pending_connect_code(text)
|
||||
if connect_code:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any
|
||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
@@ -65,7 +65,6 @@ class SlackChannel(Channel):
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._web_client_factory = config.get("web_client_factory")
|
||||
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
|
||||
configured_bot_user_id = config.get("bot_user_id")
|
||||
@@ -295,18 +294,13 @@ class SlackChannel(Channel):
|
||||
|
||||
user_id = event.get("user", "")
|
||||
|
||||
# Check allowed users
|
||||
if self._allowed_users and user_id not in self._allowed_users:
|
||||
logger.debug("Ignoring message from non-allowed user: %s", user_id)
|
||||
return
|
||||
|
||||
text = event.get("text", "").strip()
|
||||
if event.get("type") == "app_mention":
|
||||
text = _strip_leading_slack_bot_mention(text, self._bot_user_id)
|
||||
if not text:
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
connect_code = self._pending_connect_code(text)
|
||||
if connect_code:
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
@@ -319,6 +313,12 @@ class SlackChannel(Channel):
|
||||
)
|
||||
return
|
||||
|
||||
# Check allowed users after connect-code handling so browser-initiated
|
||||
# binding can bootstrap a new external identity.
|
||||
if self._allowed_users and user_id not in self._allowed_users:
|
||||
logger.debug("Ignoring message from non-allowed user: %s", user_id)
|
||||
return
|
||||
|
||||
channel_id = event.get("channel", "")
|
||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ class TelegramChannel(Channel):
|
||||
# stream_key ("chat_id:thread_ts") -> state of the in-flight streamed
|
||||
# bot message being edited in place: {"message_id", "last_edit_at", "last_text"}
|
||||
self._stream_messages: dict[str, dict[str, Any]] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -463,13 +462,15 @@ class TelegramChannel(Channel):
|
||||
|
||||
async def _cmd_start(self, update, context) -> None:
|
||||
"""Handle /start command."""
|
||||
if not self._check_user(update.effective_user.id):
|
||||
return
|
||||
args = getattr(context, "args", []) if context is not None else []
|
||||
if args:
|
||||
# Handle the deep-link bind token before applying allowed_users so a
|
||||
# browser-initiated bind can bootstrap a new external identity.
|
||||
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||
if handled:
|
||||
return
|
||||
if not self._check_user(update.effective_user.id):
|
||||
return
|
||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||
|
||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||
|
||||
@@ -22,7 +22,7 @@ from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
@@ -254,7 +254,6 @@ class WechatChannel(Channel):
|
||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._load_state()
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -591,24 +590,16 @@ class WechatChannel(Channel):
|
||||
return
|
||||
|
||||
chat_id = str(raw_message.get("from_user_id") or raw_message.get("ilink_user_id") or "").strip()
|
||||
if not chat_id or not self._check_user(chat_id):
|
||||
if not chat_id:
|
||||
return
|
||||
|
||||
text = self._extract_text(raw_message)
|
||||
files = await self._extract_inbound_files(raw_message)
|
||||
if not text and not files:
|
||||
return
|
||||
|
||||
context_token = str(raw_message.get("context_token") or "").strip()
|
||||
thread_ts = context_token or str(raw_message.get("client_id") or raw_message.get("msg_id") or "").strip() or None
|
||||
|
||||
if context_token:
|
||||
self._context_tokens_by_chat[chat_id] = context_token
|
||||
if thread_ts:
|
||||
self._context_tokens_by_thread[thread_ts] = context_token
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
# Handle the connect code before applying allowed_users so a browser-initiated
|
||||
# bind can bootstrap an external identity that is not yet whitelisted.
|
||||
connect_code = self._pending_connect_code(text)
|
||||
if connect_code:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
@@ -617,6 +608,20 @@ class WechatChannel(Channel):
|
||||
if handled:
|
||||
return
|
||||
|
||||
if not self._check_user(chat_id):
|
||||
return
|
||||
|
||||
files = await self._extract_inbound_files(raw_message)
|
||||
if not text and not files:
|
||||
return
|
||||
|
||||
thread_ts = context_token or str(raw_message.get("client_id") or raw_message.get("msg_id") or "").strip() or None
|
||||
|
||||
if context_token:
|
||||
self._context_tokens_by_chat[chat_id] = context_token
|
||||
if thread_ts:
|
||||
self._context_tokens_by_thread[thread_ts] = context_token
|
||||
|
||||
inbound = self._make_inbound(
|
||||
chat_id=chat_id,
|
||||
user_id=chat_id,
|
||||
|
||||
@@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
InboundMessage,
|
||||
@@ -31,7 +31,6 @@ class WeComChannel(Channel):
|
||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||
self._ws_stream_ids: dict[str, str] = {}
|
||||
self._working_message = "Working on it..."
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -295,8 +294,8 @@ class WeComChannel(Channel):
|
||||
|
||||
user_id = (body.get("from") or {}).get("userid")
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
connect_code = self._pending_connect_code(text)
|
||||
if connect_code:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id=str(user_id or ""),
|
||||
|
||||
@@ -111,6 +111,8 @@ Feishu/Lark, DingTalk, WeChat, and WeCom:
|
||||
|
||||
Codes use 128 bits of randomness, expire after 10 minutes, and are single-use.
|
||||
|
||||
For providers with an `allowed_users` allowlist (Telegram, Slack, DingTalk, WeChat, …), a valid `/connect <code>` (or Telegram `/start <code>`) is consumed **before** the allowlist is checked. This is intentional: a user who is not yet on the allowlist — and whose platform identity the bot has therefore never seen — can still complete their first browser-initiated bind. After binding, `allowed_users` continues to gate ordinary (non-bind) messages as before.
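
As a rough illustration of the code properties listed above (128 bits of randomness, 10-minute expiry, single use), here is a minimal in-memory sketch. The names `issue_code`, `consume_code`, and `_pending` are hypothetical; DeerFlow's real implementation persists and consumes codes through its own store, which is not reproduced here.

```python
import secrets
import time
from typing import Dict, Optional, Tuple

CODE_TTL_SECONDS = 600  # 10 minutes, as stated in the text

# Hypothetical in-memory store: code -> (owner_user_id, expires_at).
_pending: Dict[str, Tuple[str, float]] = {}


def issue_code(owner_user_id: str, now: Optional[float] = None) -> str:
    """Create a connect code for the user who started the browser flow."""
    issued_at = time.time() if now is None else now
    code = secrets.token_urlsafe(16)  # 16 random bytes = 128 bits
    _pending[code] = (owner_user_id, issued_at + CODE_TTL_SECONDS)
    return code


def consume_code(code: str, now: Optional[float] = None) -> Optional[str]:
    """Return the owner for a valid code, or None if unknown, used, or expired."""
    current = time.time() if now is None else now
    entry = _pending.pop(code, None)  # pop() makes the code single-use
    if entry is None:
        return None
    owner_user_id, expires_at = entry
    return owner_user_id if current < expires_at else None
```

Note that `consume_code` takes no sender identity at all: whoever presents the code first consumes it, which is why the code itself has to stay confidential.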
|
||||
|
||||
## Runtime Model
|
||||
|
||||
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
|
||||
@@ -126,6 +128,8 @@ Incoming messages that resolve to a connection carry `connection_id`, `owner_use
- Browser APIs remain authenticated and CSRF-protected.
- Connect codes are 128-bit random, short-lived, and single-use.
- `allowed_users` is **not** a bind-time defense. Because connect codes are processed before the allowlist (see Connect Flow), anyone who possesses a valid code can consume it — not only allowlisted users. Bind security therefore rests entirely on the code's confidentiality: it is 128-bit random, expires after 10 minutes, is single-use, and is shown only in the initiating user's browser (never echoed back to chat). Treat connect codes like one-time passwords and do not forward them.
- An external identity — `(provider, external account, workspace/team/guild)` — has at most one active owner. The most recent successful bind wins: connecting an identity that another DeerFlow user already holds transfers ownership and revokes the previous owner's binding (and its stored credentials). This is enforced at the database layer, so two users racing to bind the same identity cannot both end up connected.
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets.
- This implementation does not add public provider callback or webhook routes.
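
The database-level guarantee behind the single-active-owner bullet above can be demonstrated with a small SQLite sketch. The table layout, the index name `uq_active_identity`, and the helpers `insert_connected` and `transfer_to` are simplified assumptions for illustration; the real schema has more columns and an additional per-owner unique constraint, and the real repository wraps the transfer in a bounded retry loop.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE channel_connections (
        id INTEGER PRIMARY KEY,
        owner_user_id TEXT NOT NULL,
        provider TEXT NOT NULL,
        external_account_id TEXT NOT NULL,
        workspace_id TEXT NOT NULL,
        status TEXT NOT NULL
    )
    """
)
# Partial unique index: at most one non-revoked row per external identity.
conn.execute(
    """
    CREATE UNIQUE INDEX uq_active_identity
    ON channel_connections (provider, external_account_id, workspace_id)
    WHERE status != 'revoked'
    """
)

IDENTITY = ("slack", "U-shared", "T1")
INSERT_SQL = (
    "INSERT INTO channel_connections "
    "(owner_user_id, provider, external_account_id, workspace_id, status) "
    "VALUES (?, ?, ?, ?, 'connected')"
)


def insert_connected(owner: str) -> None:
    """Insert a connected row without revoking anyone (may violate the index)."""
    with conn:  # commits on success, rolls back on error
        conn.execute(INSERT_SQL, (owner, *IDENTITY))


def transfer_to(owner: str) -> None:
    """Revoke other owners' active rows first, then insert, in one transaction."""
    with conn:
        conn.execute(
            "UPDATE channel_connections SET status = 'revoked' "
            "WHERE provider = ? AND external_account_id = ? AND workspace_id = ? "
            "AND owner_user_id != ? AND status != 'revoked'",
            (*IDENTITY, owner),
        )
        conn.execute(INSERT_SQL, (owner, *IDENTITY))
```

After `insert_connected("alice")`, a plain `insert_connected("bob")` raises `sqlite3.IntegrityError`, whereas `transfer_to("bob")` succeeds because Alice's row is revoked before Bob's row is written.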
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@@ -46,6 +46,20 @@ class ChannelConnectionRow(Base):
|
||||
name="uq_channel_connection_owner_provider_identity",
|
||||
),
|
||||
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
|
||||
# Enforce the single-active-owner invariant at the database layer: at most
|
||||
# one non-revoked row may exist per external identity. This makes ownership
|
||||
# transfer race-safe (concurrent connects from different owners can no
|
||||
# longer both commit a connected row). Partial unique indexes are
|
||||
# supported by both SQLite (>= 3.8.0) and PostgreSQL.
|
||||
Index(
|
||||
"uq_channel_connection_active_identity",
|
||||
"provider",
|
||||
"external_account_id",
|
||||
"workspace_id",
|
||||
unique=True,
|
||||
sqlite_where=text("status != 'revoked'"),
|
||||
postgresql_where=text("status != 'revoked'"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,12 @@ from deerflow.utils.time import coerce_iso
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Bounded retries for upsert_connection when a concurrent writer commits a
|
||||
# conflicting row first (same owner identity, or the same active external
|
||||
# identity guarded by the partial unique index). Each retry re-reads the
|
||||
# now-visible state, so a small bound converges under realistic contention.
|
||||
_UPSERT_MAX_ATTEMPTS = 3
|
||||
|
||||
|
||||
class ChannelCredentialCipher:
|
||||
"""Encrypts provider credentials before they are persisted."""
|
||||
@@ -128,36 +134,62 @@ class ChannelConnectionRepository:
|
||||
row.capabilities_json = dict(capabilities or {})
|
||||
row.metadata_json = dict(metadata or {})
|
||||
|
||||
async def _revoke_other_active_owners(session: AsyncSession) -> None:
|
||||
if status != "connected":
|
||||
return
|
||||
with session.no_autoflush:
|
||||
result = await session.execute(
|
||||
select(ChannelConnectionRow.id).where(
|
||||
ChannelConnectionRow.provider == provider,
|
||||
ChannelConnectionRow.external_account_id == external_account_id_value,
|
||||
ChannelConnectionRow.workspace_id == workspace_id_value,
|
||||
ChannelConnectionRow.owner_user_id != owner_user_id,
|
||||
ChannelConnectionRow.status != "revoked",
|
||||
)
|
||||
)
|
||||
transferred_ids = [row_id for row_id in result.scalars()]
|
||||
if not transferred_ids:
|
||||
return
|
||||
await session.execute(update(ChannelConnectionRow).where(ChannelConnectionRow.id.in_(transferred_ids)).values(status="revoked"))
|
||||
await session.execute(delete(ChannelCredentialRow).where(ChannelCredentialRow.connection_id.in_(transferred_ids)))
|
||||
|
||||
stmt = select(ChannelConnectionRow).where(
|
||||
ChannelConnectionRow.owner_user_id == owner_user_id,
|
||||
ChannelConnectionRow.provider == provider,
|
||||
ChannelConnectionRow.external_account_id == external_account_id_value,
|
||||
ChannelConnectionRow.workspace_id == workspace_id_value,
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ChannelConnectionRow(
|
||||
id=self._new_id(),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
external_account_id=external_account_id_value,
|
||||
workspace_id=workspace_id_value,
|
||||
)
|
||||
session.add(row)
|
||||
|
||||
_apply(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError:
|
||||
# A concurrent writer inserted the same identity first; retry as
|
||||
# an update of that row.
|
||||
await session.rollback()
|
||||
row = (await session.execute(stmt)).scalar_one()
|
||||
_apply(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._connection_to_dict(row)
|
||||
async with self.session_factory() as session:
|
||||
last_error: IntegrityError | None = None
|
||||
for _ in range(_UPSERT_MAX_ATTEMPTS):
|
||||
try:
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
# Revoke any other owner's active row for this external identity
|
||||
# *before* our connected row is flushed, so the partial unique
|
||||
# index on active identities is satisfied at commit time.
|
||||
await _revoke_other_active_owners(session)
|
||||
if row is None:
|
||||
row = ChannelConnectionRow(
|
||||
id=self._new_id(),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
external_account_id=external_account_id_value,
|
||||
workspace_id=workspace_id_value,
|
||||
)
|
||||
session.add(row)
|
||||
_apply(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._connection_to_dict(row)
|
||||
except IntegrityError as exc:
|
||||
# A concurrent writer committed a conflicting row first (this
|
||||
# owner's identity, or the same active external identity). Roll
|
||||
# back and retry: the next pass re-reads the now-visible state,
|
||||
# revokes the newly-committed owner, and writes our row.
|
||||
last_error = exc
|
||||
await session.rollback()
|
||||
raise last_error # type: ignore[misc] # loop runs at least once
|
||||
|
||||
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
|
||||
async with self.session_factory() as session:
|
||||
|
@@ -5,7 +5,35 @@ from __future__ import annotations
 from datetime import UTC, datetime, timedelta
 from unittest.mock import AsyncMock, MagicMock
 
-from app.channels.message_bus import InboundMessage, MessageBus
+from app.channels.base import Channel
+from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage
+
+
+class _StubChannel(Channel):
+    """Minimal concrete Channel used to exercise base-class helpers directly."""
+
+    async def start(self) -> None:  # pragma: no cover - not exercised
+        pass
+
+    async def stop(self) -> None:  # pragma: no cover - not exercised
+        pass
+
+    async def send(self, msg: OutboundMessage) -> None:  # pragma: no cover - not exercised
+        pass
+
+
+def test_pending_connect_code_extracts_code_when_connections_configured():
+    channel = _StubChannel(name="stub", bus=MessageBus(), config={"connection_repo": object()})
+    # A connect command yields its code; ordinary text does not.
+    assert channel._pending_connect_code("/connect abc123") == "abc123"
+    assert channel._pending_connect_code("hello world") is None
+
+
+def test_pending_connect_code_is_none_when_connections_disabled():
+    # With no connection repo, binding is not configured and connect codes are
+    # ignored so the message falls through to normal handling.
+    channel = _StubChannel(name="stub", bus=MessageBus(), config={})
+    assert channel._pending_connect_code("/connect abc123") is None
 
 
 async def _make_repo(tmp_path, name: str):

@@ -88,6 +88,119 @@ class TestChannelConnectionRepository:
         assert second["external_account_name"] == "Alice Telegram"
         assert len(await repo.list_connections("alice")) == 1
 
+    @pytest.mark.anyio
+    async def test_upsert_connection_transfers_external_identity_between_owners(self, repo):
+        await repo.upsert_connection(
+            owner_user_id="alice",
+            provider="slack",
+            external_account_id="U-shared",
+            workspace_id="T1",
+            status="connected",
+        )
+
+        bob = await repo.upsert_connection(
+            owner_user_id="bob",
+            provider="slack",
+            external_account_id="U-shared",
+            workspace_id="T1",
+            status="connected",
+        )
+
+        alice_rows = await repo.list_connections("alice")
+        resolved = await repo.find_connection_by_external_identity(
+            provider="slack",
+            external_account_id="U-shared",
+            workspace_id="T1",
+        )
+
+        assert alice_rows[0]["status"] == "revoked"
+        assert bob["status"] == "connected"
+        assert resolved is not None
+        assert resolved["owner_user_id"] == "bob"
+        assert resolved["id"] == bob["id"]
+
+    @pytest.mark.anyio
+    async def test_active_identity_unique_index_rejects_second_connected_owner(self, repo):
+        # The single-active-owner invariant must be enforced by the database, not
+        # only by the app-level revoke step (which can race under READ COMMITTED).
+        from sqlalchemy.exc import IntegrityError
+
+        await repo.upsert_connection(
+            owner_user_id="alice",
+            provider="slack",
+            external_account_id="U-shared",
+            workspace_id="T1",
+            status="connected",
+        )
+
+        with pytest.raises(IntegrityError):
+            async with repo.session_factory() as session:
+                session.add(
+                    ChannelConnectionRow(
+                        id="manual-duplicate-active",
+                        owner_user_id="bob",
+                        provider="slack",
+                        external_account_id="U-shared",
+                        workspace_id="T1",
+                        status="connected",
+                    )
+                )
+                await session.commit()
+
+    @pytest.mark.anyio
+    async def test_active_identity_unique_index_allows_revoked_rows(self, repo):
+        # A revoked row must not occupy the active-identity slot, so a fresh
+        # connected bind for the same identity is allowed afterwards.
+        first = await repo.upsert_connection(
+            owner_user_id="alice",
+            provider="slack",
+            external_account_id="U-shared",
+            workspace_id="T1",
+            status="connected",
+        )
+        await repo.disconnect_connection(connection_id=first["id"], owner_user_id="alice")
+
+        second = await repo.upsert_connection(
+            owner_user_id="bob",
+            provider="slack",
+            external_account_id="U-shared",
+            workspace_id="T1",
+            status="connected",
+        )
+        assert second["status"] == "connected"
+
+    @pytest.mark.anyio
+    async def test_concurrent_upserts_keep_single_active_owner(self, repo):
+        import asyncio
+
+        async def connect(owner: str):
+            return await repo.upsert_connection(
+                owner_user_id=owner,
+                provider="slack",
+                external_account_id="U-shared",
+                workspace_id="T1",
+                status="connected",
+            )
+
+        await asyncio.gather(connect("alice"), connect("bob"))
+
+        async with repo.session_factory() as session:
+            connected = (
+                (
+                    await session.execute(
+                        select(ChannelConnectionRow).where(
+                            ChannelConnectionRow.provider == "slack",
+                            ChannelConnectionRow.external_account_id == "U-shared",
+                            ChannelConnectionRow.workspace_id == "T1",
+                            ChannelConnectionRow.status == "connected",
+                        )
+                    )
+                )
+                .scalars()
+                .all()
+            )
+        assert len(connected) == 1
+
     @pytest.mark.anyio
     async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
         connection = await repo.upsert_connection(

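The repository tests above rely on the database, not application code, rejecting a second non-revoked row for the same external identity. A minimal standalone sketch of that idea follows, using the standard-library `sqlite3` module rather than the project's SQLAlchemy models. The table layout, the index name `uq_active_identity`, and the helpers `open_db`, `bind`, and `active_owners` are assumptions made for illustration; only the indexed columns and the `status != 'revoked'` predicate come from the commit message.

```python
import sqlite3

# Hypothetical schema: only the indexed columns and the partial-index
# predicate are taken from the commit message; the rest is illustrative.
SCHEMA = """
CREATE TABLE channel_connections (
    id TEXT PRIMARY KEY,
    owner_user_id TEXT NOT NULL,
    provider TEXT NOT NULL,
    external_account_id TEXT NOT NULL,
    workspace_id TEXT NOT NULL,
    status TEXT NOT NULL
);
CREATE UNIQUE INDEX uq_active_identity
    ON channel_connections (provider, external_account_id, workspace_id)
    WHERE status != 'revoked';
"""

IDENTITY = ("slack", "U-shared", "T1")


def open_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA)
    return conn


def bind(conn: sqlite3.Connection, row_id: str, owner: str, *, revoke_others: bool) -> None:
    """Insert a connected row, optionally revoking other owners' active rows first."""
    with conn:  # commits on success, rolls back if an exception escapes
        if revoke_others:
            conn.execute(
                "UPDATE channel_connections SET status = 'revoked' "
                "WHERE provider = ? AND external_account_id = ? AND workspace_id = ? "
                "AND owner_user_id != ? AND status != 'revoked'",
                (*IDENTITY, owner),
            )
        conn.execute(
            "INSERT INTO channel_connections VALUES (?, ?, ?, ?, ?, 'connected')",
            (row_id, owner, *IDENTITY),
        )


def active_owners(conn: sqlite3.Connection) -> list:
    rows = conn.execute(
        "SELECT owner_user_id FROM channel_connections WHERE status != 'revoked'"
    ).fetchall()
    return [row[0] for row in rows]
```

In this sketch, a second `bind(..., revoke_others=False)` for the same identity raises `sqlite3.IntegrityError`, while revoking first frees the slot inside the same transaction. This mirrors the revoke-before-insert ordering the commit message describes, without the retry loop.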
@@ -4881,6 +4881,41 @@ class TestSlackAllowedUsers:
         assert inbound.chat_id == "C123"
         assert inbound.text == "hello from slack"
 
+    def test_connect_code_bypasses_allowed_users_filter(self):
+        from app.channels.slack import SlackChannel
+
+        bus = MessageBus()
+        bus.publish_inbound = AsyncMock()
+        channel = SlackChannel(
+            bus=bus,
+            config={"allowed_users": ["U-allowed"], "connection_repo": object()},
+        )
+        channel._loop = MagicMock()
+        channel._loop.is_running.return_value = True
+        channel._bind_connection_from_connect_code = AsyncMock(return_value=True)
+        channel._add_reaction = MagicMock()
+        channel._send_running_reply = MagicMock()
+
+        event = {
+            "user": "U-blocked",
+            "text": "/connect slack-bind-code",
+            "team": "T123",
+            "channel": "C123",
+            "ts": "1710000000.000100",
+        }
+
+        with patch(
+            "app.channels.slack.asyncio.run_coroutine_threadsafe",
+            side_effect=self._submit_coro,
+        ) as submit:
+            channel._handle_message_event(event)
+
+        channel._bind_connection_from_connect_code.assert_called_once()
+        submit.assert_called_once()
+        bus.publish_inbound.assert_not_awaited()
+        channel._add_reaction.assert_not_called()
+        channel._send_running_reply.assert_not_called()
+
     def test_app_mention_strips_leading_bot_mention_before_command_detection(self):
         from app.channels.slack import SlackChannel
 

@@ -435,6 +435,49 @@ class TestAllowedUsersFiltering:
 
         _run(go())
 
+    def test_non_allowed_user_message_content_not_logged(self, caplog):
+        import logging
+
+        async def go():
+            bus = MessageBus()
+            bus.publish_inbound = AsyncMock()
+            channel = DingTalkChannel(bus, config={"allowed_users": ["user_001"]})
+            channel._client_id = "test_key"
+            channel._main_loop = asyncio.get_event_loop()
+            channel._running = True
+
+            msg = _make_chatbot_message(sender_staff_id="user_blocked", text="secret blocked content")
+            with caplog.at_level(logging.INFO, logger="app.channels.dingtalk"):
+                channel._on_chatbot_message(msg)
+                await asyncio.sleep(0.1)
+
+            bus.publish_inbound.assert_not_awaited()
+            # The parsed-message INFO log (with message content) must not fire for
+            # a blocked sender — allowed_users still acts as a privacy/noise filter.
+            assert "parsed message" not in caplog.text
+            assert "secret blocked content" not in caplog.text
+
+        _run(go())
+
+    def test_connect_code_bypasses_allowed_users_filter(self):
+        async def go():
+            bus = MessageBus()
+            bus.publish_inbound = AsyncMock()
+            channel = DingTalkChannel(bus, config={"allowed_users": ["user_001"], "connection_repo": object()})
+            channel._client_id = "test_key"
+            channel._main_loop = asyncio.get_event_loop()
+            channel._running = True
+            channel._bind_connection_from_connect_code = AsyncMock(return_value=True)
+
+            msg = _make_chatbot_message(sender_staff_id="user_blocked", text="/connect dingtalk-bind-code")
+            channel._on_chatbot_message(msg)
+
+            await asyncio.sleep(0.1)
+            channel._bind_connection_from_connect_code.assert_awaited_once()
+            bus.publish_inbound.assert_not_awaited()
+
+        _run(go())
+
     def test_empty_allowed_users_allows_all(self):
         async def go():
             bus = MessageBus()

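The Slack and DingTalk tests above both assert the same ordering: a connect command is dispatched to the bind flow before the `allowed_users` gate runs, and ordinary text from a non-allowlisted sender is still dropped. A minimal sketch of that ordering follows. The names `pending_connect_code`, `route_inbound`, and `CONNECT_PREFIX` are hypothetical and are not the adapters' real API; only the `/connect <code>` command shape, the "no connection repo means no connect codes" rule, and the "empty allowlist allows all" rule are taken from the tests in this diff.

```python
from typing import Collection, Optional

# Assumed command shape, matching the "/connect abc123" form used in the tests.
CONNECT_PREFIX = "/connect "


def pending_connect_code(text: str, connections_enabled: bool) -> Optional[str]:
    """Return the code carried by a connect command, or None for ordinary text.

    When binding is not configured (no connection repo), connect codes are
    ignored so the message falls through to normal handling.
    """
    if not connections_enabled or not text.startswith(CONNECT_PREFIX):
        return None
    return text[len(CONNECT_PREFIX):].strip() or None


def route_inbound(
    sender: str,
    text: str,
    allowed_users: Collection[str],
    connections_enabled: bool,
) -> str:
    """Decide what to do with an inbound message: 'bind', 'drop', or 'publish'."""
    # 1. Connect codes are handled first so an unbound newcomer can bootstrap.
    if pending_connect_code(text, connections_enabled) is not None:
        return "bind"
    # 2. Only then does the allowlist apply; an empty allowlist allows everyone.
    if allowed_users and sender not in allowed_users:
        return "drop"
    return "publish"
```

Swapping steps 1 and 2 in this sketch reproduces the deadlock the commit message describes: the sender is rejected before the connect code is ever inspected.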
@@ -71,6 +71,38 @@ async def test_start_with_deep_link_state_binds_telegram_chat(repo):
     assert "connected" in update.message.reply_text.await_args.args[0].lower()
 
 
+@pytest.mark.anyio
+async def test_start_token_bypasses_allowed_users_filter(repo):
+    # A newly allowlisted-but-unbound user must be able to bootstrap their first
+    # bind via the deep-link start token even though their Telegram id is not yet
+    # in allowed_users. The allowed_users gate must run after token handling.
+    state = "telegram-bind-state"
+    await repo.create_oauth_state(
+        owner_user_id="deerflow-user-1",
+        provider="telegram",
+        state=state,
+        expires_at=datetime.now(UTC) + timedelta(minutes=5),
+    )
+    channel = TelegramChannel(
+        bus=MessageBus(),
+        config={
+            "bot_token": "test-token",
+            "connection_repo": repo,
+            "allowed_users": [999],  # newcomer (42) is not whitelisted
+        },
+    )
+    update = _telegram_update(text=f"/start {state}", user_id=42)
+    context = MagicMock()
+    context.args = [state]
+
+    await channel._cmd_start(update, context)
+
+    connections = await repo.list_connections("deerflow-user-1")
+    assert len(connections) == 1
+    assert connections[0]["external_account_id"] == "42"
+    assert "connected" in update.message.reply_text.await_args.args[0].lower()
+
+
 @pytest.mark.anyio
 async def test_bound_telegram_message_publishes_connection_identity(repo):
     connection = await repo.upsert_connection(

@@ -7,6 +7,7 @@ import base64
 import json
 from pathlib import Path
 from typing import Any
+from unittest.mock import AsyncMock
 
 from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage
 
@@ -359,6 +360,66 @@ def test_allowed_users_filter_blocks_non_whitelisted_sender():
     _run(go())
 
 
+def test_connect_code_bypasses_allowed_users_filter(tmp_path: Path):
+    from app.channels.wechat import WechatChannel
+    from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
+    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+
+    async def go():
+        from datetime import UTC, datetime, timedelta
+
+        await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'wechat.db'}", sqlite_dir=str(tmp_path))
+        try:
+            repo = ChannelConnectionRepository(
+                get_session_factory(),
+                cipher=ChannelCredentialCipher.from_key("wechat-secret"),
+            )
+            code = "wechat-bind-code"
+            await repo.create_oauth_state(
+                owner_user_id="deerflow-user-1",
+                provider="wechat",
+                state=code,
+                expires_at=datetime.now(UTC) + timedelta(minutes=5),
+            )
+
+            bus = MessageBus()
+            published = []
+
+            async def capture(msg):
+                published.append(msg)
+
+            bus.publish_inbound = capture  # type: ignore[method-assign]
+
+            # The newcomer ("blocked-user") is not in allowed_users yet, but a valid
+            # /connect code must still bootstrap their first bind.
+            channel = WechatChannel(
+                bus=bus,
+                config={"bot_token": "test-token", "allowed_users": ["allowed-user"], "connection_repo": repo},
+            )
+            channel._send_connection_reply = AsyncMock()  # type: ignore[method-assign]
+
+            await channel._handle_update(
+                {
+                    "message_type": 1,
+                    "from_user_id": "blocked-user",
+                    "context_token": "ctx-connect",
+                    "item_list": [{"type": 1, "text_item": {"text": f"/connect {code}"}}],
+                }
+            )
+
+            connections = await repo.list_connections("deerflow-user-1")
+            assert len(connections) == 1
+            assert connections[0]["provider"] == "wechat"
+            assert connections[0]["external_account_id"] == "blocked-user"
+            # The connect-code reply was sent and no normal inbound was published.
+            channel._send_connection_reply.assert_awaited_once()
+            assert published == []
+        finally:
+            await close_engine()
+
+    _run(go())
+
+
 def test_send_uses_cached_context_token(monkeypatch):
     from app.channels.wechat import WechatChannel
 

@@ -0,0 +1,93 @@
+import type { ChannelConnection, ChannelProviderId } from "./types";
+
+export const CONNECT_POLL_INTERVAL_MS = 2000;
+// Fallback bind window used when the backend response omits or garbles
+// `expires_in`, so a non-finite value can never produce an unbounded poll loop.
+const DEFAULT_CONNECT_EXPIRES_S = 600;
+
+export interface ConnectPollHandle {
+  cancel: () => void;
+}
+
+export interface ConnectPollOptions {
+  provider: ChannelProviderId;
+  expiresInSeconds: number;
+  /** Fetch the latest connections — the single source of truth for "connected". */
+  fetchConnections: () => Promise<ChannelConnection[]>;
+  /** Invoked once when the provider's connection resolves to "connected". */
+  onConnected: () => void;
+  intervalMs?: number;
+  now?: () => number;
+}
+
+/**
+ * Poll the connections endpoint until the given provider reports `connected`
+ * or the bind window elapses. Returns a handle whose `cancel()` stops the loop
+ * (used to dedup repeated connects and to clean up on unmount).
+ *
+ * Only the connections endpoint is polled; `onConnected` lets the caller refresh
+ * derived provider state exactly once when the bind lands, instead of fetching
+ * both endpoints on every tick.
+ */
+export function startConnectionPoll(
+  options: ConnectPollOptions,
+): ConnectPollHandle {
+  const {
+    provider,
+    expiresInSeconds,
+    fetchConnections,
+    onConnected,
+    intervalMs = CONNECT_POLL_INTERVAL_MS,
+    now = Date.now,
+  } = options;
+
+  const expires =
+    Number.isFinite(expiresInSeconds) && expiresInSeconds > 0
+      ? expiresInSeconds
+      : DEFAULT_CONNECT_EXPIRES_S;
+  const deadline = now() + expires * 1000;
+
+  let timer: ReturnType<typeof setTimeout> | undefined;
+  let cancelled = false;
+
+  const cancel = () => {
+    cancelled = true;
+    if (timer !== undefined) {
+      clearTimeout(timer);
+      timer = undefined;
+    }
+  };
+
+  const schedule = () => {
+    timer = setTimeout(() => {
+      timer = undefined;
+      if (cancelled) {
+        return;
+      }
+      void fetchConnections()
+        .then((connections) => {
+          if (cancelled) {
+            return;
+          }
+          const connected = connections.some(
+            (item) => item.provider === provider && item.status === "connected",
+          );
+          if (connected) {
+            onConnected();
+            return;
+          }
+          if (now() < deadline) {
+            schedule();
+          }
+        })
+        .catch(() => {
+          if (!cancelled && now() < deadline) {
+            schedule();
+          }
+        });
+    }, intervalMs);
+  };
+
+  schedule();
+  return { cancel };
+}
@@ -1,4 +1,5 @@
 import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
+import { useEffect, useRef } from "react";
 
 import {
   configureChannelProvider,
@@ -8,6 +9,7 @@ import {
   listChannelConnections,
   listChannelProviders,
 } from "./api";
+import { startConnectionPoll, type ConnectPollHandle } from "./connect-poll";
 import type { ChannelProviderId, ChannelRuntimeConfigValues } from "./types";
 
 export const channelProviderQueryKey = ["channelProviders"] as const;
@@ -36,14 +38,49 @@ export function useChannelConnections() {
 
 export function useConnectChannelProvider() {
   const queryClient = useQueryClient();
+  const pollersRef = useRef<Map<ChannelProviderId, ConnectPollHandle>>(
+    new Map(),
+  );
+
+  // Cancel any in-flight polls when the component using this hook unmounts.
+  useEffect(() => {
+    const pollers = pollersRef.current;
+    return () => {
+      pollers.forEach((handle) => handle.cancel());
+      pollers.clear();
+    };
+  }, []);
+
   return useMutation({
     mutationFn: (provider: ChannelProviderId) =>
       connectChannelProvider(provider),
-    onSuccess: () => {
+    onSuccess: (result, provider) => {
       void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
       void queryClient.invalidateQueries({
         queryKey: channelConnectionsQueryKey,
       });
+
+      // Replace any existing poll for this provider so repeated Connect clicks
+      // don't spawn parallel polling chains racing on the same query keys.
+      pollersRef.current.get(provider)?.cancel();
+      pollersRef.current.set(
+        provider,
+        startConnectionPoll({
+          provider,
+          expiresInSeconds: result.expires_in,
+          fetchConnections: () =>
+            queryClient.fetchQuery({
+              queryKey: channelConnectionsQueryKey,
+              queryFn: () => listChannelConnections(),
+            }),
+          onConnected: () => {
+            // Refresh derived provider state exactly once when the bind lands.
+            void queryClient.invalidateQueries({
+              queryKey: channelProviderQueryKey,
+            });
+          },
+        }),
+      );
     },
   });
 }

@@ -0,0 +1,101 @@
+import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
+
+import { startConnectionPoll } from "@/core/channels/connect-poll";
+import type { ChannelConnection } from "@/core/channels/types";
+
+function connection(provider: string, status: string): ChannelConnection {
+  return {
+    id: `${provider}-1`,
+    provider,
+    status,
+    scopes: [],
+    metadata: {},
+  };
+}
+
+beforeEach(() => {
+  vi.useFakeTimers();
+});
+
+afterEach(() => {
+  vi.useRealTimers();
+});
+
+describe("startConnectionPoll", () => {
+  test("polls connections until the provider is connected, then resolves once", async () => {
+    const responses: ChannelConnection[][] = [
+      [connection("telegram", "pending")],
+      [connection("telegram", "connected")],
+    ];
+    const fetchConnections = vi.fn(async () => responses.shift() ?? []);
+    const onConnected = vi.fn();
+
+    startConnectionPoll({
+      provider: "telegram",
+      expiresInSeconds: 600,
+      fetchConnections,
+      onConnected,
+      intervalMs: 1000,
+    });
+
+    await vi.advanceTimersByTimeAsync(1000);
+    expect(fetchConnections).toHaveBeenCalledTimes(1);
+    expect(onConnected).not.toHaveBeenCalled();
+
+    await vi.advanceTimersByTimeAsync(1000);
+    expect(fetchConnections).toHaveBeenCalledTimes(2);
+    expect(onConnected).toHaveBeenCalledTimes(1);
+
+    // No further polling after the connection resolves.
+    await vi.advanceTimersByTimeAsync(5000);
+    expect(fetchConnections).toHaveBeenCalledTimes(2);
+  });
+
+  test("cancel() stops scheduled polling and fires no further fetches", async () => {
+    const fetchConnections = vi.fn(async () => [
+      connection("telegram", "pending"),
+    ]);
+    const handle = startConnectionPoll({
+      provider: "telegram",
+      expiresInSeconds: 600,
+      fetchConnections,
+      onConnected: vi.fn(),
+      intervalMs: 1000,
+    });
+
+    await vi.advanceTimersByTimeAsync(1000);
+    expect(fetchConnections).toHaveBeenCalledTimes(1);
+
+    handle.cancel();
+    await vi.advanceTimersByTimeAsync(10000);
+    expect(fetchConnections).toHaveBeenCalledTimes(1);
+  });
+
+  test("a non-finite expires_in falls back to a finite deadline and terminates", async () => {
+    const fetchConnections = vi.fn(async () => [
+      connection("telegram", "pending"),
+    ]);
+    let nowValue = 0;
+    startConnectionPoll({
+      provider: "telegram",
+      expiresInSeconds: Number.NaN,
+      fetchConnections,
+      onConnected: vi.fn(),
+      intervalMs: 1000,
+      now: () => nowValue,
+    });
+
+    nowValue = 1;
+    await vi.advanceTimersByTimeAsync(1000);
+    expect(fetchConnections).toHaveBeenCalledTimes(1);
+
+    // Jump past the fallback expiry window: the loop must stop instead of
+    // running forever (Date.now() >= NaN would otherwise never be true).
+    nowValue = 10_000_000;
+    await vi.advanceTimersByTimeAsync(1000);
+    expect(fetchConnections).toHaveBeenCalledTimes(2);
+
+    await vi.advanceTimersByTimeAsync(10000);
+    expect(fetchConnections).toHaveBeenCalledTimes(2);
+  });
+});