fix(channels): make channel connect flow deterministic (#3582)

* fix(channels): make channel connect flow deterministic

* make format

* fix(channels): apply connect-code before allowed_users on telegram and wechat

The bind-bootstrap reorder shipped for slack/dingtalk only. Telegram and
WeChat still gate _check_user/allowed_users before connect-code dispatch, so
a newly allowlisted-but-unbound user is silently rejected when binding via the
browser deep-link / connect-code flow — the same deadlock the PR fixes.

- telegram: consume the /start deep-link token before the allowed_users gate.
- wechat: handle the /connect code before the allowed_users gate, and defer
  inbound file extraction + context-token tracking past the gate so blocked
  senders no longer trigger CDN downloads or token bookkeeping.

Adds regression tests for both adapters mirroring the slack/dingtalk coverage.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): enforce single-active-owner invariant at the DB layer

_revoke_other_active_owners did a SELECT-then-UPDATE in app code with no row
lock or constraint covering active rows. Under READ COMMITTED, two concurrent
connect-code consumes for the same (provider, external_account_id, workspace_id)
from different owners could each observe "no other active owner" and both commit
a connected row, leaving find_connection_by_external_identity nondeterministic.

- Add a partial unique index on (provider, external_account_id, workspace_id)
  WHERE status != 'revoked' (portable to SQLite >= 3.8.0 and PostgreSQL) so the
  database guarantees at most one non-revoked row per external identity.
- Reorder upsert_connection to revoke other owners' active rows before the new
  connected row is flushed (so the index is satisfied at commit), wrapped in a
  bounded rollback-and-retry loop. A losing concurrent writer now retries
  against the now-visible state instead of committing a duplicate.

Adds DB-constraint, revoked-slot-reuse, and concurrent-upsert regression tests.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): harden connect-status polling primitive

pollChannelConnectionUntilResolved was a free-floating recursive setTimeout
started from onSuccess with no cancellation, no per-provider dedup, a redundant
second endpoint per tick, and an unbounded loop on a non-finite expires_in.

- Extract a framework-agnostic, cancellable poller (connect-poll.ts) that polls
  only listChannelConnections() and invalidates the providers query once when the
  bind resolves, instead of fetching both endpoints every tick.
- Guard expires_in with a finite check + default window so undefined/NaN can no
  longer produce a poll loop that runs until the page closes.
- Track one active poll handle per provider in useConnectChannelProvider via a
  ref Map: a new connect cancels the prior poll for that provider, and a useEffect
  cleanup cancels all polls on unmount.

Adds unit tests for resolve-and-stop, cancellation, and non-finite-expiry.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): stop leaking blocked-sender content in DingTalk INFO log; document bind semantics

Moving the allowed_users gate past _extract_text meant the parsed-message INFO
log (text=%r, first 100 chars) fired for senders that allowed_users would have
rejected, defeating the filter's noise/privacy role. Move that log to after the
allowed_users gate so blocked senders' message text never reaches INFO logs.

Also document the two operator-relevant semantic changes in backend/CLAUDE.md:
connect-code dispatch runs before allowed_users (so allowed_users is no longer a
bind-time defense; the model relies on code confidentiality + 600s TTL + one-time
consumption), and the single-active-owner-per-external-identity transfer semantics
now backed by the partial unique index.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* docs(channels): note connect-code-vs-allowlist and ownership transfer in operator guide

Mirror the backend/CLAUDE.md notes in the operator-facing IM_CHANNEL_CONNECTIONS.md:
connect codes are consumed before allowed_users (so a not-yet-allowlisted user can
still complete a first bind, and allowed_users is not a bind-time defense), and an
external identity has at most one active owner with last-bind-wins transfer enforced
at the DB layer.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* refactor(channels): lift connect-code dispatch into Channel base class

Each adapter duplicated the ordering-sensitive boilerplate of extracting a
/connect code and guarding on the connection repo before its allowed_users gate.
The duplication is what let telegram/wechat drift and keep the gate ahead of the
bind. Centralize it:

- Move `_connection_repo` onto Channel.__init__ (removing 7 duplicate assignments).
- Add Channel._pending_connect_code(text), which guards on the repo and extracts
  the code, documenting that adapters MUST consult it before authorization so a
  browser-initiated bind can bootstrap a not-yet-authorized identity.
- Route slack, discord, feishu, dingtalk, wechat, and wecom through the helper.
  This also fixes a latent inconsistency where slack dispatched a bind even when
  no connection repo was configured.

Pure refactor — the full channel suite stays green; adds a direct unit test for
the base helper's contract.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* make format

* fix(channels): redact DingTalk parsed-message INFO log content

Log text_len instead of the first 100 chars of message text, so message
content never reaches INFO logs (the after-gate move already keeps blocked
senders out entirely). This takes over the redaction from #3584 so only this
PR touches dingtalk.py, letting the two PRs merge in any order conflict-free.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Nan Gao
2026-06-18 04:15:31 +02:00
committed by GitHub
parent 8c0830aea1
commit 68ba4198b8
21 changed files with 695 additions and 80 deletions
+2
View File
@@ -420,6 +420,8 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
- Provider-level `connection_status` reflects the user's newest connection row. With no binding it is `not_connected`, except in auth-disabled local mode where a configured running channel reports `connected` because all channel messages already route to the default user.
- Slack replies use the configured operator bot token from `channels.slack` unless per-connection credentials are present; unreadable or corrupt stored credentials are treated as unavailable.
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
- **Connect-code ordering vs `allowed_users`**: inbound workers consume a valid `/connect <code>` (or Telegram `/start <code>`) **before** applying the `allowed_users` filter, so a newly allowlisted-but-unbound user can bootstrap their first bind via the browser flow. Consequence: `allowed_users` is **not** a bind-time defense — any sender who possesses a valid code can consume it (not only allowlisted users). The bind security model rests on the code's confidentiality: `secrets.token_urlsafe(16)`, 600 s TTL, one-time `consume_oauth_state`, and codes surfaced only in the initiating browser (never echoed to chat). `allowed_users` still gates ordinary (non-bind) messages.
- **Single-active-owner transfer semantics**: an external identity is keyed by `(provider, external_account_id, workspace_id)`. The latest successful bind wins — `upsert_connection` revokes other owners' active rows for the same identity (ownership transfer). This invariant is enforced at the DB layer by the partial unique index `uq_channel_connection_active_identity` (`WHERE status != 'revoked'`), so concurrent connects from different owners cannot both end `connected`; the losing writer retries against the now-visible state. `find_connection_by_external_identity` therefore resolves deterministically.
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
+15
View File
@@ -9,6 +9,7 @@ from collections.abc import Awaitable, Callable
from concurrent.futures import CancelledError as FutureCancelledError
from typing import Any, TypeVar
from app.channels.commands import extract_connect_code
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
@@ -31,6 +32,7 @@ class Channel(ABC):
self.bus = bus
self.config = config
self._running = False
self._connection_repo: Any = config.get("connection_repo")
@property
def is_running(self) -> bool:
@@ -117,6 +119,19 @@ class Channel(ABC):
if exc:
logger.error("[%s] %s failed for msg_id=%s: %s", self.name, name, msg_id, exc)
def _pending_connect_code(self, text: str) -> str | None:
"""Return the one-time bind code if *text* is a ``/connect <code>`` command
and channel connections are configured, else ``None``.
Adapters MUST consult this **before** applying their ``allowed_users`` /
``_check_user`` gate, so a browser-initiated bind can bootstrap an external
identity that the platform bot has never seen and is therefore not yet
authorized. (Telegram uses its deep-link ``/start <token>`` flow instead.)
"""
if self._connection_repo is None:
return None
return extract_connect_code(text)
def _make_inbound(
self,
chat_id: str,
+19 -17
View File
@@ -14,7 +14,7 @@ from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.commands import is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
@@ -137,7 +137,6 @@ class DingTalkChannel(Channel):
self._incoming_messages: dict[str, Any] = {}
self._incoming_messages_lock = threading.Lock()
self._card_repliers: dict[str, Any] = {}
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -366,26 +365,13 @@ class DingTalkChannel(Channel):
msg_id = message.message_id or ""
sender_nick = message.sender_nick or ""
if self._allowed_users and sender_staff_id not in self._allowed_users:
logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
return
text = self._extract_text(message)
if not text:
logger.info("[DingTalk] empty text, ignoring message")
return
logger.info(
"[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text=%r",
conversation_type,
msg_id,
sender_staff_id,
sender_nick,
text[:100],
)
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
connect_code = self._pending_connect_code(text)
if connect_code:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
@@ -402,6 +388,22 @@ class DingTalkChannel(Channel):
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
return
if self._allowed_users and sender_staff_id not in self._allowed_users:
logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
return
# Log only metadata (length, not content) so message text never reaches
# INFO logs, and only after the allowed_users gate so blocked senders are
# not logged at all.
logger.info(
"[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text_len=%d",
conversation_type,
msg_id,
sender_staff_id,
sender_nick,
len(text or ""),
)
if _is_dingtalk_command(text):
msg_type = InboundMessageType.COMMAND
else:
+2 -3
View File
@@ -10,7 +10,7 @@ from pathlib import Path
from typing import Any
from app.channels.base import Channel
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.commands import is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
@@ -71,7 +71,6 @@ class DiscordChannel(Channel):
self._discord_loop: asyncio.AbstractEventLoop | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._discord_module = None
self._connection_repo = config.get("connection_repo")
async def start(self) -> None:
if self._running:
@@ -293,7 +292,7 @@ class DiscordChannel(Channel):
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
connect_code = extract_connect_code(text)
connect_code = self._pending_connect_code(text)
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
return
+3 -4
View File
@@ -11,7 +11,7 @@ import time
from typing import Any, Literal
from app.channels.base import Channel
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.commands import is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
PENDING_CLARIFICATION_METADATA_KEY,
@@ -72,7 +72,6 @@ class FeishuChannel(Channel):
self._CreateImageRequestBody = None
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
self._connection_repo = config.get("connection_repo")
@staticmethod
def _non_empty_str(value: Any) -> str | None:
@@ -851,8 +850,8 @@ class FeishuChannel(Channel):
logger.info("[Feishu] empty text, ignoring message")
return
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
connect_code = self._pending_connect_code(text)
if connect_code:
if self._main_loop and self._main_loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
self._bind_connection_from_connect_code(
+8 -8
View File
@@ -9,7 +9,7 @@ from typing import Any
from markdown_to_mrkdwn import SlackMarkdownConverter
from app.channels.base import Channel
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.commands import is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
@@ -65,7 +65,6 @@ class SlackChannel(Channel):
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
self._connection_repo = config.get("connection_repo")
self._web_client_factory = config.get("web_client_factory")
self._connection_web_clients: dict[str, tuple[str, Any]] = {}
configured_bot_user_id = config.get("bot_user_id")
@@ -295,18 +294,13 @@ class SlackChannel(Channel):
user_id = event.get("user", "")
# Check allowed users
if self._allowed_users and user_id not in self._allowed_users:
logger.debug("Ignoring message from non-allowed user: %s", user_id)
return
text = event.get("text", "").strip()
if event.get("type") == "app_mention":
text = _strip_leading_slack_bot_mention(text, self._bot_user_id)
if not text:
return
connect_code = extract_connect_code(text)
connect_code = self._pending_connect_code(text)
if connect_code:
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
@@ -319,6 +313,12 @@ class SlackChannel(Channel):
)
return
# Check allowed users after connect-code handling so browser-initiated
# binding can bootstrap a new external identity.
if self._allowed_users and user_id not in self._allowed_users:
logger.debug("Ignoring message from non-allowed user: %s", user_id)
return
channel_id = event.get("channel", "")
thread_ts = event.get("thread_ts") or event.get("ts", "")
+4 -3
View File
@@ -52,7 +52,6 @@ class TelegramChannel(Channel):
# stream_key ("chat_id:thread_ts") -> state of the in-flight streamed
# bot message being edited in place: {"message_id", "last_edit_at", "last_text"}
self._stream_messages: dict[str, dict[str, Any]] = {}
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -463,13 +462,15 @@ class TelegramChannel(Channel):
async def _cmd_start(self, update, context) -> None:
"""Handle /start command."""
if not self._check_user(update.effective_user.id):
return
args = getattr(context, "args", []) if context is not None else []
if args:
# Handle the deep-link bind token before applying allowed_users so a
# browser-initiated bind can bootstrap a new external identity.
handled = await self._bind_connection_from_start_token(update, str(args[0]))
if handled:
return
if not self._check_user(update.effective_user.id):
return
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
+20 -15
View File
@@ -22,7 +22,7 @@ from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from app.channels.base import Channel
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.commands import is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
@@ -254,7 +254,6 @@ class WechatChannel(Channel):
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
self._connection_repo = config.get("connection_repo")
self._load_state()
async def start(self) -> None:
@@ -591,24 +590,16 @@ class WechatChannel(Channel):
return
chat_id = str(raw_message.get("from_user_id") or raw_message.get("ilink_user_id") or "").strip()
if not chat_id or not self._check_user(chat_id):
if not chat_id:
return
text = self._extract_text(raw_message)
files = await self._extract_inbound_files(raw_message)
if not text and not files:
return
context_token = str(raw_message.get("context_token") or "").strip()
thread_ts = context_token or str(raw_message.get("client_id") or raw_message.get("msg_id") or "").strip() or None
if context_token:
self._context_tokens_by_chat[chat_id] = context_token
if thread_ts:
self._context_tokens_by_thread[thread_ts] = context_token
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
# Handle the connect code before applying allowed_users so a browser-initiated
# bind can bootstrap an external identity that is not yet whitelisted.
connect_code = self._pending_connect_code(text)
if connect_code:
handled = await self._bind_connection_from_connect_code(
chat_id=chat_id,
context_token=context_token,
@@ -617,6 +608,20 @@ class WechatChannel(Channel):
if handled:
return
if not self._check_user(chat_id):
return
files = await self._extract_inbound_files(raw_message)
if not text and not files:
return
thread_ts = context_token or str(raw_message.get("client_id") or raw_message.get("msg_id") or "").strip() or None
if context_token:
self._context_tokens_by_chat[chat_id] = context_token
if thread_ts:
self._context_tokens_by_thread[thread_ts] = context_token
inbound = self._make_inbound(
chat_id=chat_id,
user_id=chat_id,
+3 -4
View File
@@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.commands import extract_connect_code, is_known_channel_command
from app.channels.commands import is_known_channel_command
from app.channels.connection_identity import attach_connection_identity
from app.channels.message_bus import (
InboundMessage,
@@ -31,7 +31,6 @@ class WeComChannel(Channel):
self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
self._connection_repo = config.get("connection_repo")
@property
def supports_streaming(self) -> bool:
@@ -295,8 +294,8 @@ class WeComChannel(Channel):
user_id = (body.get("from") or {}).get("userid")
connect_code = extract_connect_code(text)
if connect_code and self._connection_repo is not None:
connect_code = self._pending_connect_code(text)
if connect_code:
handled = await self._bind_connection_from_connect_code(
frame=frame,
user_id=str(user_id or ""),
+4
View File
@@ -111,6 +111,8 @@ Feishu/Lark, DingTalk, WeChat, and WeCom:
Codes use 128 bits of randomness, expire after 10 minutes, and are single-use.
For providers with an `allowed_users` allowlist (Telegram, Slack, DingTalk, WeChat, …), a valid `/connect <code>` (or Telegram `/start <code>`) is consumed **before** the allowlist is checked. This is intentional: a user who is not yet on the allowlist — and whose platform identity the bot has therefore never seen — can still complete their first browser-initiated bind. After binding, `allowed_users` continues to gate ordinary (non-bind) messages as before.
## Runtime Model
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
@@ -126,6 +128,8 @@ Incoming messages that resolve to a connection carry `connection_id`, `owner_use
- Browser APIs remain authenticated and CSRF-protected.
- Connect codes are 128-bit random, short-lived, and single-use.
- `allowed_users` is **not** a bind-time defense. Because connect codes are processed before the allowlist (see Connect Flow), anyone who possesses a valid code can consume it — not only allowlisted users. Bind security therefore rests entirely on the code's confidentiality: it is 128-bit random, expires after 10 minutes, is single-use, and is shown only in the initiating user's browser (never echoed back to chat). Treat connect codes like one-time passwords and do not forward them.
- An external identity — `(provider, external account, workspace/team/guild)` — has at most one active owner. The most recent successful bind wins: connecting an identity that another DeerFlow user already holds transfers ownership and revokes the previous owner's binding (and its stored credentials). This is enforced at the database layer, so two users racing to bind the same identity cannot both end up connected.
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
- Stored per-connection credentials are encrypted. If stored credential material cannot be decrypted, DeerFlow treats it as unavailable instead of using corrupt secrets.
- This implementation does not add public provider callback or webhook routes.
@@ -4,7 +4,7 @@ from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint, text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
@@ -46,6 +46,20 @@ class ChannelConnectionRow(Base):
name="uq_channel_connection_owner_provider_identity",
),
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
# Enforce the single-active-owner invariant at the database layer: at most
# one non-revoked row may exist per external identity. This makes ownership
# transfer race-safe (concurrent connects from different owners can no
# longer both commit a connected row). Partial unique indexes are
# supported by both SQLite (>= 3.8.0) and PostgreSQL.
Index(
"uq_channel_connection_active_identity",
"provider",
"external_account_id",
"workspace_id",
unique=True,
sqlite_where=text("status != 'revoked'"),
postgresql_where=text("status != 'revoked'"),
),
)
@@ -25,6 +25,12 @@ from deerflow.utils.time import coerce_iso
logger = logging.getLogger(__name__)
# Bounded retries for upsert_connection when a concurrent writer commits a
# conflicting row first (same owner identity, or the same active external
# identity guarded by the partial unique index). Each retry re-reads the
# now-visible state, so a small bound converges under realistic contention.
_UPSERT_MAX_ATTEMPTS = 3
class ChannelCredentialCipher:
"""Encrypts provider credentials before they are persisted."""
@@ -128,14 +134,41 @@ class ChannelConnectionRepository:
row.capabilities_json = dict(capabilities or {})
row.metadata_json = dict(metadata or {})
async def _revoke_other_active_owners(session: AsyncSession) -> None:
if status != "connected":
return
with session.no_autoflush:
result = await session.execute(
select(ChannelConnectionRow.id).where(
ChannelConnectionRow.provider == provider,
ChannelConnectionRow.external_account_id == external_account_id_value,
ChannelConnectionRow.workspace_id == workspace_id_value,
ChannelConnectionRow.owner_user_id != owner_user_id,
ChannelConnectionRow.status != "revoked",
)
)
transferred_ids = [row_id for row_id in result.scalars()]
if not transferred_ids:
return
await session.execute(update(ChannelConnectionRow).where(ChannelConnectionRow.id.in_(transferred_ids)).values(status="revoked"))
await session.execute(delete(ChannelCredentialRow).where(ChannelCredentialRow.connection_id.in_(transferred_ids)))
stmt = select(ChannelConnectionRow).where(
ChannelConnectionRow.owner_user_id == owner_user_id,
ChannelConnectionRow.provider == provider,
ChannelConnectionRow.external_account_id == external_account_id_value,
ChannelConnectionRow.workspace_id == workspace_id_value,
)
async with self.session_factory() as session:
last_error: IntegrityError | None = None
for _ in range(_UPSERT_MAX_ATTEMPTS):
try:
row = (await session.execute(stmt)).scalar_one_or_none()
# Revoke any other owner's active row for this external identity
# *before* our connected row is flushed, so the partial unique
# index on active identities is satisfied at commit time.
await _revoke_other_active_owners(session)
if row is None:
row = ChannelConnectionRow(
id=self._new_id(),
@@ -145,19 +178,18 @@ class ChannelConnectionRepository:
workspace_id=workspace_id_value,
)
session.add(row)
_apply(row)
try:
await session.commit()
except IntegrityError:
# A concurrent writer inserted the same identity first; retry as
# an update of that row.
await session.rollback()
row = (await session.execute(stmt)).scalar_one()
_apply(row)
await session.commit()
await session.refresh(row)
return self._connection_to_dict(row)
except IntegrityError as exc:
# A concurrent writer committed a conflicting row first (this
# owner's identity, or the same active external identity). Roll
# back and retry: the next pass re-reads the now-visible state,
# revokes the newly-committed owner, and writes our row.
last_error = exc
await session.rollback()
raise last_error # type: ignore[misc] # loop runs at least once
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
async with self.session_factory() as session:
@@ -5,7 +5,35 @@ from __future__ import annotations
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from app.channels.message_bus import InboundMessage, MessageBus
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage
class _StubChannel(Channel):
"""Minimal concrete Channel used to exercise base-class helpers directly."""
async def start(self) -> None: # pragma: no cover - not exercised
pass
async def stop(self) -> None: # pragma: no cover - not exercised
pass
async def send(self, msg: OutboundMessage) -> None: # pragma: no cover - not exercised
pass
def test_pending_connect_code_extracts_code_when_connections_configured():
channel = _StubChannel(name="stub", bus=MessageBus(), config={"connection_repo": object()})
# A connect command yields its code; ordinary text does not.
assert channel._pending_connect_code("/connect abc123") == "abc123"
assert channel._pending_connect_code("hello world") is None
def test_pending_connect_code_is_none_when_connections_disabled():
# With no connection repo, binding is not configured and connect codes are
# ignored so the message falls through to normal handling.
channel = _StubChannel(name="stub", bus=MessageBus(), config={})
assert channel._pending_connect_code("/connect abc123") is None
async def _make_repo(tmp_path, name: str):
@@ -88,6 +88,119 @@ class TestChannelConnectionRepository:
assert second["external_account_name"] == "Alice Telegram"
assert len(await repo.list_connections("alice")) == 1
@pytest.mark.anyio
async def test_upsert_connection_transfers_external_identity_between_owners(self, repo):
await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
bob = await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
alice_rows = await repo.list_connections("alice")
resolved = await repo.find_connection_by_external_identity(
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
)
assert alice_rows[0]["status"] == "revoked"
assert bob["status"] == "connected"
assert resolved is not None
assert resolved["owner_user_id"] == "bob"
assert resolved["id"] == bob["id"]
@pytest.mark.anyio
async def test_active_identity_unique_index_rejects_second_connected_owner(self, repo):
# The single-active-owner invariant must be enforced by the database, not
# only by the app-level revoke step (which can race under READ COMMITTED).
from sqlalchemy.exc import IntegrityError
await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
with pytest.raises(IntegrityError):
async with repo.session_factory() as session:
session.add(
ChannelConnectionRow(
id="manual-duplicate-active",
owner_user_id="bob",
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
)
await session.commit()
@pytest.mark.anyio
async def test_active_identity_unique_index_allows_revoked_rows(self, repo):
# A revoked row must not occupy the active-identity slot, so a fresh
# connected bind for the same identity is allowed afterwards.
first = await repo.upsert_connection(
owner_user_id="alice",
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
await repo.disconnect_connection(connection_id=first["id"], owner_user_id="alice")
second = await repo.upsert_connection(
owner_user_id="bob",
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
assert second["status"] == "connected"
@pytest.mark.anyio
async def test_concurrent_upserts_keep_single_active_owner(self, repo):
import asyncio
async def connect(owner: str):
return await repo.upsert_connection(
owner_user_id=owner,
provider="slack",
external_account_id="U-shared",
workspace_id="T1",
status="connected",
)
await asyncio.gather(connect("alice"), connect("bob"))
async with repo.session_factory() as session:
connected = (
(
await session.execute(
select(ChannelConnectionRow).where(
ChannelConnectionRow.provider == "slack",
ChannelConnectionRow.external_account_id == "U-shared",
ChannelConnectionRow.workspace_id == "T1",
ChannelConnectionRow.status == "connected",
)
)
)
.scalars()
.all()
)
assert len(connected) == 1
@pytest.mark.anyio
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
connection = await repo.upsert_connection(
+35
View File
@@ -4881,6 +4881,41 @@ class TestSlackAllowedUsers:
assert inbound.chat_id == "C123"
assert inbound.text == "hello from slack"
def test_connect_code_bypasses_allowed_users_filter(self):
from app.channels.slack import SlackChannel
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = SlackChannel(
bus=bus,
config={"allowed_users": ["U-allowed"], "connection_repo": object()},
)
channel._loop = MagicMock()
channel._loop.is_running.return_value = True
channel._bind_connection_from_connect_code = AsyncMock(return_value=True)
channel._add_reaction = MagicMock()
channel._send_running_reply = MagicMock()
event = {
"user": "U-blocked",
"text": "/connect slack-bind-code",
"team": "T123",
"channel": "C123",
"ts": "1710000000.000100",
}
with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=self._submit_coro,
) as submit:
channel._handle_message_event(event)
channel._bind_connection_from_connect_code.assert_called_once()
submit.assert_called_once()
bus.publish_inbound.assert_not_awaited()
channel._add_reaction.assert_not_called()
channel._send_running_reply.assert_not_called()
def test_app_mention_strips_leading_bot_mention_before_command_detection(self):
from app.channels.slack import SlackChannel
+43
View File
@@ -435,6 +435,49 @@ class TestAllowedUsersFiltering:
_run(go())
def test_non_allowed_user_message_content_not_logged(self, caplog):
import logging
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = DingTalkChannel(bus, config={"allowed_users": ["user_001"]})
channel._client_id = "test_key"
channel._main_loop = asyncio.get_event_loop()
channel._running = True
msg = _make_chatbot_message(sender_staff_id="user_blocked", text="secret blocked content")
with caplog.at_level(logging.INFO, logger="app.channels.dingtalk"):
channel._on_chatbot_message(msg)
await asyncio.sleep(0.1)
bus.publish_inbound.assert_not_awaited()
# The parsed-message INFO log (with message content) must not fire for
# a blocked sender — allowed_users still acts as a privacy/noise filter.
assert "parsed message" not in caplog.text
assert "secret blocked content" not in caplog.text
_run(go())
def test_connect_code_bypasses_allowed_users_filter(self):
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = DingTalkChannel(bus, config={"allowed_users": ["user_001"], "connection_repo": object()})
channel._client_id = "test_key"
channel._main_loop = asyncio.get_event_loop()
channel._running = True
channel._bind_connection_from_connect_code = AsyncMock(return_value=True)
msg = _make_chatbot_message(sender_staff_id="user_blocked", text="/connect dingtalk-bind-code")
channel._on_chatbot_message(msg)
await asyncio.sleep(0.1)
channel._bind_connection_from_connect_code.assert_awaited_once()
bus.publish_inbound.assert_not_awaited()
_run(go())
def test_empty_allowed_users_allows_all(self):
async def go():
bus = MessageBus()
@@ -71,6 +71,38 @@ async def test_start_with_deep_link_state_binds_telegram_chat(repo):
assert "connected" in update.message.reply_text.await_args.args[0].lower()
@pytest.mark.anyio
async def test_start_token_bypasses_allowed_users_filter(repo):
# A newly allowlisted-but-unbound user must be able to bootstrap their first
# bind via the deep-link start token even though their Telegram id is not yet
# in allowed_users. The allowed_users gate must run after token handling.
state = "telegram-bind-state"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="telegram",
state=state,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
channel = TelegramChannel(
bus=MessageBus(),
config={
"bot_token": "test-token",
"connection_repo": repo,
"allowed_users": [999], # newcomer (42) is not whitelisted
},
)
update = _telegram_update(text=f"/start {state}", user_id=42)
context = MagicMock()
context.args = [state]
await channel._cmd_start(update, context)
connections = await repo.list_connections("deerflow-user-1")
assert len(connections) == 1
assert connections[0]["external_account_id"] == "42"
assert "connected" in update.message.reply_text.await_args.args[0].lower()
@pytest.mark.anyio
async def test_bound_telegram_message_publishes_connection_identity(repo):
connection = await repo.upsert_connection(
+61
View File
@@ -7,6 +7,7 @@ import base64
import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage
@@ -359,6 +360,66 @@ def test_allowed_users_filter_blocks_non_whitelisted_sender():
_run(go())
def test_connect_code_bypasses_allowed_users_filter(tmp_path: Path):
from app.channels.wechat import WechatChannel
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
async def go():
from datetime import UTC, datetime, timedelta
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'wechat.db'}", sqlite_dir=str(tmp_path))
try:
repo = ChannelConnectionRepository(
get_session_factory(),
cipher=ChannelCredentialCipher.from_key("wechat-secret"),
)
code = "wechat-bind-code"
await repo.create_oauth_state(
owner_user_id="deerflow-user-1",
provider="wechat",
state=code,
expires_at=datetime.now(UTC) + timedelta(minutes=5),
)
bus = MessageBus()
published = []
async def capture(msg):
published.append(msg)
bus.publish_inbound = capture # type: ignore[method-assign]
# The newcomer ("blocked-user") is not in allowed_users yet, but a valid
# /connect code must still bootstrap their first bind.
channel = WechatChannel(
bus=bus,
config={"bot_token": "test-token", "allowed_users": ["allowed-user"], "connection_repo": repo},
)
channel._send_connection_reply = AsyncMock() # type: ignore[method-assign]
await channel._handle_update(
{
"message_type": 1,
"from_user_id": "blocked-user",
"context_token": "ctx-connect",
"item_list": [{"type": 1, "text_item": {"text": f"/connect {code}"}}],
}
)
connections = await repo.list_connections("deerflow-user-1")
assert len(connections) == 1
assert connections[0]["provider"] == "wechat"
assert connections[0]["external_account_id"] == "blocked-user"
# The connect-code reply was sent and no normal inbound was published.
channel._send_connection_reply.assert_awaited_once()
assert published == []
finally:
await close_engine()
_run(go())
def test_send_uses_cached_context_token(monkeypatch):
from app.channels.wechat import WechatChannel
@@ -0,0 +1,93 @@
import type { ChannelConnection, ChannelProviderId } from "./types";
export const CONNECT_POLL_INTERVAL_MS = 2000;
// Fallback bind window used when the backend response omits or garbles
// `expires_in`, so a non-finite value can never produce an unbounded poll loop.
const DEFAULT_CONNECT_EXPIRES_S = 600;
export interface ConnectPollHandle {
cancel: () => void;
}
export interface ConnectPollOptions {
provider: ChannelProviderId;
expiresInSeconds: number;
/** Fetch the latest connections — the single source of truth for "connected". */
fetchConnections: () => Promise<ChannelConnection[]>;
/** Invoked once when the provider's connection resolves to "connected". */
onConnected: () => void;
intervalMs?: number;
now?: () => number;
}
/**
* Poll the connections endpoint until the given provider reports `connected`
* or the bind window elapses. Returns a handle whose `cancel()` stops the loop
* (used to dedup repeated connects and to clean up on unmount).
*
* Only the connections endpoint is polled; `onConnected` lets the caller refresh
* derived provider state exactly once when the bind lands, instead of fetching
* both endpoints on every tick.
*/
export function startConnectionPoll(
options: ConnectPollOptions,
): ConnectPollHandle {
const {
provider,
expiresInSeconds,
fetchConnections,
onConnected,
intervalMs = CONNECT_POLL_INTERVAL_MS,
now = Date.now,
} = options;
const expires =
Number.isFinite(expiresInSeconds) && expiresInSeconds > 0
? expiresInSeconds
: DEFAULT_CONNECT_EXPIRES_S;
const deadline = now() + expires * 1000;
let timer: ReturnType<typeof setTimeout> | undefined;
let cancelled = false;
const cancel = () => {
cancelled = true;
if (timer !== undefined) {
clearTimeout(timer);
timer = undefined;
}
};
const schedule = () => {
timer = setTimeout(() => {
timer = undefined;
if (cancelled) {
return;
}
void fetchConnections()
.then((connections) => {
if (cancelled) {
return;
}
const connected = connections.some(
(item) => item.provider === provider && item.status === "connected",
);
if (connected) {
onConnected();
return;
}
if (now() < deadline) {
schedule();
}
})
.catch(() => {
if (!cancelled && now() < deadline) {
schedule();
}
});
}, intervalMs);
};
schedule();
return { cancel };
}
+38 -1
View File
@@ -1,4 +1,5 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef } from "react";
import {
configureChannelProvider,
@@ -8,6 +9,7 @@ import {
listChannelConnections,
listChannelProviders,
} from "./api";
import { startConnectionPoll, type ConnectPollHandle } from "./connect-poll";
import type { ChannelProviderId, ChannelRuntimeConfigValues } from "./types";
export const channelProviderQueryKey = ["channelProviders"] as const;
@@ -36,14 +38,49 @@ export function useChannelConnections() {
export function useConnectChannelProvider() {
const queryClient = useQueryClient();
const pollersRef = useRef<Map<ChannelProviderId, ConnectPollHandle>>(
new Map(),
);
// Cancel any in-flight polls when the component using this hook unmounts.
useEffect(() => {
const pollers = pollersRef.current;
return () => {
pollers.forEach((handle) => handle.cancel());
pollers.clear();
};
}, []);
return useMutation({
mutationFn: (provider: ChannelProviderId) =>
connectChannelProvider(provider),
onSuccess: () => {
onSuccess: (result, provider) => {
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
void queryClient.invalidateQueries({
queryKey: channelConnectionsQueryKey,
});
// Replace any existing poll for this provider so repeated Connect clicks
// don't spawn parallel polling chains racing on the same query keys.
pollersRef.current.get(provider)?.cancel();
pollersRef.current.set(
provider,
startConnectionPoll({
provider,
expiresInSeconds: result.expires_in,
fetchConnections: () =>
queryClient.fetchQuery({
queryKey: channelConnectionsQueryKey,
queryFn: () => listChannelConnections(),
}),
onConnected: () => {
// Refresh derived provider state exactly once when the bind lands.
void queryClient.invalidateQueries({
queryKey: channelProviderQueryKey,
});
},
}),
);
},
});
}
@@ -0,0 +1,101 @@
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
import { startConnectionPoll } from "@/core/channels/connect-poll";
import type { ChannelConnection } from "@/core/channels/types";
function connection(provider: string, status: string): ChannelConnection {
return {
id: `${provider}-1`,
provider,
status,
scopes: [],
metadata: {},
};
}
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
});
describe("startConnectionPoll", () => {
test("polls connections until the provider is connected, then resolves once", async () => {
const responses: ChannelConnection[][] = [
[connection("telegram", "pending")],
[connection("telegram", "connected")],
];
const fetchConnections = vi.fn(async () => responses.shift() ?? []);
const onConnected = vi.fn();
startConnectionPoll({
provider: "telegram",
expiresInSeconds: 600,
fetchConnections,
onConnected,
intervalMs: 1000,
});
await vi.advanceTimersByTimeAsync(1000);
expect(fetchConnections).toHaveBeenCalledTimes(1);
expect(onConnected).not.toHaveBeenCalled();
await vi.advanceTimersByTimeAsync(1000);
expect(fetchConnections).toHaveBeenCalledTimes(2);
expect(onConnected).toHaveBeenCalledTimes(1);
// No further polling after the connection resolves.
await vi.advanceTimersByTimeAsync(5000);
expect(fetchConnections).toHaveBeenCalledTimes(2);
});
test("cancel() stops scheduled polling and fires no further fetches", async () => {
const fetchConnections = vi.fn(async () => [
connection("telegram", "pending"),
]);
const handle = startConnectionPoll({
provider: "telegram",
expiresInSeconds: 600,
fetchConnections,
onConnected: vi.fn(),
intervalMs: 1000,
});
await vi.advanceTimersByTimeAsync(1000);
expect(fetchConnections).toHaveBeenCalledTimes(1);
handle.cancel();
await vi.advanceTimersByTimeAsync(10000);
expect(fetchConnections).toHaveBeenCalledTimes(1);
});
test("a non-finite expires_in falls back to a finite deadline and terminates", async () => {
const fetchConnections = vi.fn(async () => [
connection("telegram", "pending"),
]);
let nowValue = 0;
startConnectionPoll({
provider: "telegram",
expiresInSeconds: Number.NaN,
fetchConnections,
onConnected: vi.fn(),
intervalMs: 1000,
now: () => nowValue,
});
nowValue = 1;
await vi.advanceTimersByTimeAsync(1000);
expect(fetchConnections).toHaveBeenCalledTimes(1);
// Jump past the fallback expiry window: the loop must stop instead of
// running forever (Date.now() >= NaN would otherwise never be true).
nowValue = 10_000_000;
await vi.advanceTimersByTimeAsync(1000);
expect(fetchConnections).toHaveBeenCalledTimes(2);
await vi.advanceTimersByTimeAsync(10000);
expect(fetchConnections).toHaveBeenCalledTimes(2);
});
});