deer-flow/backend/packages/harness/deerflow/persistence/channel_connections/sql.py
Nan Gao 68ba4198b8 fix(channels): make channel connect flow deterministic (#3582)
* fix(channels): make channel connect flow deterministic

* make format

* fix(channels): apply connect-code before allowed_users on telegram and wechat

The bind-bootstrap reorder shipped for slack/dingtalk only. Telegram and
WeChat still gate _check_user/allowed_users before connect-code dispatch, so
a newly allowlisted-but-unbound user is silently rejected when binding via the
browser deep-link / connect-code flow — the same deadlock the PR fixes.

- telegram: consume the /start deep-link token before the allowed_users gate.
- wechat: handle the /connect code before the allowed_users gate, and defer
  inbound file extraction + context-token tracking past the gate so blocked
  senders no longer trigger CDN downloads or token bookkeeping.

Adds regression tests for both adapters mirroring the slack/dingtalk coverage.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): enforce single-active-owner invariant at the DB layer

_revoke_other_active_owners did a SELECT-then-UPDATE in app code with no row
lock or constraint covering active rows. Under READ COMMITTED, two concurrent
connect-code consumes for the same (provider, external_account_id, workspace_id)
from different owners could each observe "no other active owner" and both commit
a connected row, leaving find_connection_by_external_identity nondeterministic.

- Add a partial unique index on (provider, external_account_id, workspace_id)
  WHERE status != 'revoked' (portable to SQLite >= 3.8.0 and PostgreSQL) so the
  database guarantees at most one non-revoked row per external identity.
- Reorder upsert_connection to revoke other owners' active rows before the new
  connected row is flushed (so the index is satisfied at commit), wrapped in a
  bounded rollback-and-retry loop. A losing concurrent writer now retries
  against the now-visible state instead of committing a duplicate.

Adds DB-constraint, revoked-slot-reuse, and concurrent-upsert regression tests.
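
The index described above can be exercised in isolation. The following is a minimal sketch using the standard-library `sqlite3` module and a simplified, hypothetical `channel_connections` table; the column set, the index name `uq_active_identity`, and the `insert` helper are assumptions for illustration, not the project's actual schema. It shows that a second non-revoked row for the same external identity is rejected, and that revoking the first row frees the slot.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE channel_connections (
        id TEXT PRIMARY KEY,
        owner_user_id TEXT NOT NULL,
        provider TEXT NOT NULL,
        external_account_id TEXT NOT NULL DEFAULT '',
        workspace_id TEXT NOT NULL DEFAULT '',
        status TEXT NOT NULL
    );
    -- Partial unique index: only non-revoked rows take part in uniqueness.
    CREATE UNIQUE INDEX uq_active_identity
        ON channel_connections (provider, external_account_id, workspace_id)
        WHERE status != 'revoked';
    """
)


def insert(row_id: str, owner: str, status: str = "connected") -> None:
    # Hypothetical helper: every row targets the same external identity.
    with conn:  # commits on success, rolls back on error
        conn.execute(
            "INSERT INTO channel_connections VALUES (?, ?, 'slack', 'U1', 'W1', ?)",
            (row_id, owner, status),
        )


insert("a", "alice")
try:
    insert("b", "bob")  # second active owner for the same identity
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True

# Revoking the first owner frees the slot, so the transfer now succeeds.
with conn:
    conn.execute("UPDATE channel_connections SET status = 'revoked' WHERE id = 'a'")
insert("b", "bob")
active = conn.execute(
    "SELECT owner_user_id FROM channel_connections WHERE status != 'revoked'"
).fetchall()
```

The revoke-then-insert order in the sketch mirrors the reorder in `upsert_connection`: the previous owner's row has to leave the index before the new connected row is written.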

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): harden connect-status polling primitive

pollChannelConnectionUntilResolved was a free-floating recursive setTimeout
started from onSuccess with no cancellation, no per-provider dedup, a redundant
second endpoint per tick, and an unbounded loop on a non-finite expires_in.

- Extract a framework-agnostic, cancellable poller (connect-poll.ts) that polls
  only listChannelConnections() and invalidates the providers query once when the
  bind resolves, instead of fetching both endpoints every tick.
- Guard expires_in with a finite check + default window so undefined/NaN can no
  longer produce a poll loop that runs until the page closes.
- Track one active poll handle per provider in useConnectChannelProvider via a
  ref Map: a new connect cancels the prior poll for that provider, and a useEffect
  cleanup cancels all polls on unmount.

Adds unit tests for resolve-and-stop, cancellation, and non-finite-expiry.
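
The actual poller lives in the TypeScript frontend (`connect-poll.ts`) and is not shown on this page. As a rough Python analog of the two ideas above (a finite-window guard and a cancellable poll), the sketch below may help; `poll_window`, `poll_until_resolved`, `DEFAULT_WINDOW_SECONDS`, and the positive-value check are all assumptions for illustration, not the project's code.

```python
import asyncio
import math

DEFAULT_WINDOW_SECONDS = 600.0  # assumed default; the real value is not shown


def poll_window(expires_in, default=DEFAULT_WINDOW_SECONDS):
    """Return a finite, positive polling window, falling back to a default."""
    if isinstance(expires_in, (int, float)) and math.isfinite(expires_in) and expires_in > 0:
        return float(expires_in)
    return default


async def poll_until_resolved(check, expires_in, interval=1.0):
    """Await check() until it returns True or the bounded window elapses."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + poll_window(expires_in)
    while loop.time() < deadline:
        if await check():
            return True
        await asyncio.sleep(interval)
    return False


async def demo():
    calls = 0

    async def bound_after_three_polls():
        nonlocal calls
        calls += 1
        return calls >= 3

    # A NaN expiry falls back to the default window instead of looping forever.
    resolved = await poll_until_resolved(bound_after_three_polls, float("nan"), interval=0.001)

    async def never_bound():
        return False

    # Cancelling the task stops the poll, as a new connect or an unmount would.
    task = asyncio.create_task(poll_until_resolved(never_bound, 60, interval=0.001))
    await asyncio.sleep(0.01)
    task.cancel()
    try:
        await task
        cancelled = False
    except asyncio.CancelledError:
        cancelled = True
    return resolved, calls, cancelled
```

Calling `asyncio.run(demo())` exercises both paths: one poll that resolves and one that is cancelled.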

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): stop leaking blocked-sender content in DingTalk INFO log; document bind semantics

Moving the allowed_users gate past _extract_text meant the parsed-message INFO
log (text=%r, first 100 chars) fired for senders that allowed_users would have
rejected, defeating the filter's noise/privacy role. Move that log to after the
allowed_users gate so blocked senders' message text never reaches INFO logs.

Also document the two operator-relevant semantic changes in backend/CLAUDE.md:
connect-code dispatch runs before allowed_users (so allowed_users is no longer a
bind-time defense; the model relies on code confidentiality + 600s TTL + one-time
consumption), and the single-active-owner-per-external-identity transfer semantics
now backed by the partial unique index.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* docs(channels): note connect-code-vs-allowlist and ownership transfer in operator guide

Mirror the backend/CLAUDE.md notes in the operator-facing IM_CHANNEL_CONNECTIONS.md:
connect codes are consumed before allowed_users (so a not-yet-allowlisted user can
still complete a first bind, and allowed_users is not a bind-time defense), and an
external identity has at most one active owner with last-bind-wins transfer enforced
at the DB layer.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* refactor(channels): lift connect-code dispatch into Channel base class

Each adapter duplicated the ordering-sensitive boilerplate of extracting a
/connect code and guarding on the connection repo before its allowed_users gate.
The duplication is what let telegram/wechat drift and keep the gate ahead of the
bind. Centralize it:

- Move `_connection_repo` onto Channel.__init__ (removing 7 duplicate assignments).
- Add Channel._pending_connect_code(text), which guards on the repo and extracts
  the code, documenting that adapters MUST consult it before authorization so a
  browser-initiated bind can bootstrap a not-yet-authorized identity.
- Route slack, discord, feishu, dingtalk, wechat, and wecom through the helper.
  This also fixes a latent inconsistency where slack dispatched a bind even when
  no connection repo was configured.

Pure refactor — the full channel suite stays green; adds a direct unit test for
the base helper's contract.
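
The ordering contract described above can be sketched as follows. This is a hypothetical illustration only: the `/connect` pattern, the `ExampleAdapter` class, its `handle` method, and the return values are assumptions, since the real `Channel` base class and adapters are not shown on this page.

```python
import re

# Assumed command syntax; the real extraction logic is not shown in the source.
_CONNECT_RE = re.compile(r"^/connect\s+(\S+)\s*$")


class Channel:
    def __init__(self, connection_repo=None):
        # Held once on the base class instead of being assigned per adapter.
        self._connection_repo = connection_repo

    def _pending_connect_code(self, text):
        """Return the connect code in text, or None if there is none or no repo."""
        if self._connection_repo is None:
            return None
        match = _CONNECT_RE.match(text.strip())
        return match.group(1) if match else None


class ExampleAdapter(Channel):
    """Hypothetical adapter showing connect-code dispatch before authorization."""

    def __init__(self, connection_repo=None, allowed_users=()):
        super().__init__(connection_repo)
        self._allowed_users = set(allowed_users)

    def handle(self, user_id, text):
        # 1. Consult the helper first so an unbound user can bootstrap a bind.
        code = self._pending_connect_code(text)
        if code is not None:
            return ("bind", code)
        # 2. Only then apply the allowed_users gate.
        if self._allowed_users and user_id not in self._allowed_users:
            return ("blocked", None)
        return ("message", text)


adapter = ExampleAdapter(connection_repo=object(), allowed_users={"u1"})
bind_result = adapter.handle("u2", "/connect abc123")
blocked_result = adapter.handle("u2", "hello")
no_repo_result = ExampleAdapter(None, {"u1"}).handle("u2", "/connect abc123")
```

With no repository configured the helper returns `None`, so the message falls through to the normal gate, which corresponds to the slack inconsistency noted in the list above.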

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* make format

* fix(channels): redact DingTalk parsed-message INFO log content

Log text_len instead of the first 100 chars of message text, so message
content never reaches INFO logs (the after-gate move already keeps blocked
senders out entirely). This takes over the redaction from #3584 so only this
PR touches dingtalk.py, letting the two PRs merge in any order conflict-free.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 10:15:31 +08:00

550 lines
24 KiB
Python

"""SQL repository for user-owned IM channel connections."""

from __future__ import annotations

import base64
import hashlib
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any

from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from deerflow.persistence.channel_connections.model import (
    ChannelConnectionRow,
    ChannelConversationRow,
    ChannelCredentialRow,
    ChannelOAuthStateRow,
)
from deerflow.utils.time import coerce_iso

logger = logging.getLogger(__name__)

# Bounded retries for upsert_connection when a concurrent writer commits a
# conflicting row first (same owner identity, or the same active external
# identity guarded by the partial unique index). Each retry re-reads the
# now-visible state, so a small bound converges under realistic contention.
_UPSERT_MAX_ATTEMPTS = 3


class ChannelCredentialCipher:
    """Encrypts provider credentials before they are persisted."""

    def __init__(self, fernet: Fernet) -> None:
        self._fernet = fernet

    @classmethod
    def from_key(cls, key: str) -> ChannelCredentialCipher:
        digest = hashlib.sha256(key.encode("utf-8")).digest()
        return cls(Fernet(base64.urlsafe_b64encode(digest)))

    def encrypt_text(self, value: str | None) -> str | None:
        if value is None:
            return None
        return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")

    def decrypt_text(self, value: str | None) -> str | None:
        if value is None:
            return None
        token = value.removeprefix("fernet:v1:")
        return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")


class ChannelConnectionRepository:
    """Persistence facade for channel connections, credentials, and conversations."""

    def __init__(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        *,
        cipher: ChannelCredentialCipher | None = None,
    ) -> None:
        self.session_factory = session_factory
        self._cipher = cipher

    async def close(self) -> None:
        from deerflow.persistence.engine import close_engine

        await close_engine()

    @staticmethod
    def _new_id() -> str:
        return uuid.uuid4().hex

    @staticmethod
    def _normalize_optional_identity(value: str | None) -> str:
        return value or ""

    @staticmethod
    def _coerce_datetime(value: datetime | None) -> datetime | None:
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=UTC)

    def _encrypt_optional_secret(self, value: str | None) -> str | None:
        if value is None:
            return None
        if self._cipher is None:
            raise RuntimeError("channel connection encryption key is required")
        return self._cipher.encrypt_text(value)

    @staticmethod
    def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
        data = row.to_dict()
        data["external_account_id"] = data["external_account_id"] or None
        data["workspace_id"] = data["workspace_id"] or None
        data["scopes"] = data.pop("scopes_json") or []
        data["capabilities"] = data.pop("capabilities_json") or {}
        data["metadata"] = data.pop("metadata_json") or {}
        for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
            value = data.get(key)
            if isinstance(value, datetime):
                data[key] = coerce_iso(value)
        return data

    async def upsert_connection(
        self,
        *,
        owner_user_id: str,
        provider: str,
        external_account_id: str | None = None,
        external_account_name: str | None = None,
        workspace_id: str | None = None,
        workspace_name: str | None = None,
        bot_user_id: str | None = None,
        scopes: list[str] | None = None,
        capabilities: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        status: str = "connected",
    ) -> dict[str, Any]:
        external_account_id_value = self._normalize_optional_identity(external_account_id)
        workspace_id_value = self._normalize_optional_identity(workspace_id)

        def _apply(row: ChannelConnectionRow) -> None:
            row.status = status
            row.external_account_name = external_account_name
            row.workspace_name = workspace_name
            row.bot_user_id = bot_user_id
            row.scopes_json = list(scopes or [])
            row.capabilities_json = dict(capabilities or {})
            row.metadata_json = dict(metadata or {})

        async def _revoke_other_active_owners(session: AsyncSession) -> None:
            if status != "connected":
                return
            with session.no_autoflush:
                result = await session.execute(
                    select(ChannelConnectionRow.id).where(
                        ChannelConnectionRow.provider == provider,
                        ChannelConnectionRow.external_account_id == external_account_id_value,
                        ChannelConnectionRow.workspace_id == workspace_id_value,
                        ChannelConnectionRow.owner_user_id != owner_user_id,
                        ChannelConnectionRow.status != "revoked",
                    )
                )
                transferred_ids = [row_id for row_id in result.scalars()]
            if not transferred_ids:
                return
            await session.execute(update(ChannelConnectionRow).where(ChannelConnectionRow.id.in_(transferred_ids)).values(status="revoked"))
            await session.execute(delete(ChannelCredentialRow).where(ChannelCredentialRow.connection_id.in_(transferred_ids)))

        stmt = select(ChannelConnectionRow).where(
            ChannelConnectionRow.owner_user_id == owner_user_id,
            ChannelConnectionRow.provider == provider,
            ChannelConnectionRow.external_account_id == external_account_id_value,
            ChannelConnectionRow.workspace_id == workspace_id_value,
        )
        async with self.session_factory() as session:
            last_error: IntegrityError | None = None
            for _ in range(_UPSERT_MAX_ATTEMPTS):
                try:
                    row = (await session.execute(stmt)).scalar_one_or_none()
                    # Revoke any other owner's active row for this external identity
                    # *before* our connected row is flushed, so the partial unique
                    # index on active identities is satisfied at commit time.
                    await _revoke_other_active_owners(session)
                    if row is None:
                        row = ChannelConnectionRow(
                            id=self._new_id(),
                            owner_user_id=owner_user_id,
                            provider=provider,
                            external_account_id=external_account_id_value,
                            workspace_id=workspace_id_value,
                        )
                        session.add(row)
                    _apply(row)
                    await session.commit()
                    await session.refresh(row)
                    return self._connection_to_dict(row)
                except IntegrityError as exc:
                    # A concurrent writer committed a conflicting row first (this
                    # owner's identity, or the same active external identity). Roll
                    # back and retry: the next pass re-reads the now-visible state,
                    # revokes the newly-committed owner, and writes our row.
                    last_error = exc
                    await session.rollback()
            raise last_error  # type: ignore[misc]  # loop runs at least once

    async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
        async with self.session_factory() as session:
            result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
            return [self._connection_to_dict(row) for row in result.scalars()]

    async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
        async with self.session_factory() as session:
            row = await session.get(ChannelConnectionRow, connection_id)
            if row is None or row.owner_user_id != owner_user_id:
                return False
            row.status = "revoked"
            credential = await session.get(ChannelCredentialRow, connection_id)
            if credential is not None:
                await session.delete(credential)
            await session.commit()
            return True

    async def disconnect_provider_connections(self, *, provider: str) -> int:
        """Revoke all active user connections for an instance-wide provider removal."""
        async with self.session_factory() as session:
            result = await session.execute(
                select(ChannelConnectionRow.id).where(
                    ChannelConnectionRow.provider == provider,
                    ChannelConnectionRow.status != "revoked",
                )
            )
            connection_ids = [row_id for row_id in result.scalars()]
            if not connection_ids:
                return 0
            await session.execute(update(ChannelConnectionRow).where(ChannelConnectionRow.id.in_(connection_ids)).values(status="revoked"))
            await session.execute(delete(ChannelCredentialRow).where(ChannelCredentialRow.connection_id.in_(connection_ids)))
            await session.commit()
            return len(connection_ids)

    async def store_credentials(
        self,
        connection_id: str,
        *,
        access_token: str | None,
        refresh_token: str | None = None,
        token_type: str | None = None,
        expires_at: datetime | None = None,
        refresh_expires_at: datetime | None = None,
        extra: dict[str, Any] | None = None,
    ) -> None:
        if self._cipher is None:
            raise RuntimeError("channel connection encryption key is required")
        async with self.session_factory() as session:
            row = await session.get(ChannelCredentialRow, connection_id)
            if row is None:
                row = ChannelCredentialRow(connection_id=connection_id)
                session.add(row)
            row.encrypted_access_token = self._cipher.encrypt_text(access_token)
            row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
            row.token_type = token_type
            row.expires_at = expires_at
            row.refresh_expires_at = refresh_expires_at
            row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
            row.version = (row.version or 0) + 1
            await session.commit()

    async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
        if self._cipher is None:
            return None
        async with self.session_factory() as session:
            row = await session.get(ChannelCredentialRow, connection_id)
            if row is None:
                return None
            try:
                extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
                return {
                    "connection_id": row.connection_id,
                    "access_token": self._cipher.decrypt_text(row.encrypted_access_token),
                    "refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
                    "token_type": row.token_type,
                    "expires_at": self._coerce_datetime(row.expires_at),
                    "refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
                    "extra": json.loads(extra_raw) if extra_raw else {},
                }
            except (InvalidToken, UnicodeError, json.JSONDecodeError):
                logger.warning(
                    "Unable to decrypt channel connection credentials; treating credentials as unavailable",
                    exc_info=True,
                )
                return None

    @staticmethod
    def hash_state(state: str) -> str:
        return hashlib.sha256(state.encode("utf-8")).hexdigest()

    async def create_oauth_state(
        self,
        *,
        owner_user_id: str,
        provider: str,
        state: str,
        expires_at: datetime,
        code_verifier: str | None = None,
        nonce_hash: str | None = None,
        redirect_after: str | None = None,
        requested_scopes: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        row = ChannelOAuthStateRow(
            state_hash=self.hash_state(state),
            owner_user_id=owner_user_id,
            provider=provider,
            code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
            nonce_hash=nonce_hash,
            redirect_after=redirect_after,
            requested_scopes_json=list(requested_scopes or []),
            metadata_json=dict(metadata or {}),
            expires_at=expires_at,
        )
        async with self.session_factory() as session:
            session.add(row)
            await session.commit()

    async def create_oauth_state_within_cap(
        self,
        *,
        owner_user_id: str,
        provider: str,
        state: str,
        expires_at: datetime,
        max_pending: int,
        now: datetime | None = None,
        code_verifier: str | None = None,
        nonce_hash: str | None = None,
        redirect_after: str | None = None,
        requested_scopes: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """Atomically enforce the per-(owner, provider) pending cap, then insert.

        delete-expired + count + insert run in a single transaction serialized
        per (owner, provider), so concurrent connect requests cannot each
        observe ``count < max_pending`` and all insert (which would leak past
        the cap). PostgreSQL takes a transaction-scoped advisory lock; SQLite
        serializes writers through the write lock the leading DELETE acquires.

        Returns ``True`` when the row was inserted, ``False`` when the cap is
        already reached.
        """
        current_time = now or datetime.now(UTC)
        async with self.session_factory() as session:
            await self._serialize_oauth_owner_scope(session, owner_user_id, provider)
            # Prune only this owner/provider's expired codes (the ones that affect
            # this cap), not every user's, which avoids a global DELETE on each
            # connect POST. Issuing this write first also takes the SQLite database
            # write lock so the count below cannot race a concurrent inserter
            # between count and commit. Stale codes for other owners are pruned
            # globally by consume_oauth_state / delete_expired_oauth_states.
            await session.execute(
                delete(ChannelOAuthStateRow).where(
                    ChannelOAuthStateRow.owner_user_id == owner_user_id,
                    ChannelOAuthStateRow.provider == provider,
                    ChannelOAuthStateRow.expires_at < current_time,
                )
            )
            pending = await session.execute(
                select(func.count())
                .select_from(ChannelOAuthStateRow)
                .where(
                    ChannelOAuthStateRow.owner_user_id == owner_user_id,
                    ChannelOAuthStateRow.provider == provider,
                    ChannelOAuthStateRow.consumed_at.is_(None),
                    ChannelOAuthStateRow.expires_at >= current_time,
                )
            )
            if int(pending.scalar_one()) >= max_pending:
                await session.rollback()
                return False
            session.add(
                ChannelOAuthStateRow(
                    state_hash=self.hash_state(state),
                    owner_user_id=owner_user_id,
                    provider=provider,
                    code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
                    nonce_hash=nonce_hash,
                    redirect_after=redirect_after,
                    requested_scopes_json=list(requested_scopes or []),
                    metadata_json=dict(metadata or {}),
                    expires_at=expires_at,
                )
            )
            await session.commit()
            return True

    async def _serialize_oauth_owner_scope(self, session: AsyncSession, owner_user_id: str, provider: str) -> None:
        """Serialize concurrent pending-cap transactions for one (owner, provider).

        On PostgreSQL this takes a transaction-scoped advisory lock so concurrent
        issuers run their count+insert one at a time. On SQLite the leading
        DELETE in the caller's transaction already acquires the database write
        lock, which serializes writers, so no extra lock is required.
        """
        try:
            dialect = session.bind.dialect.name if session.bind is not None else ""
        except Exception:
            dialect = ""
        if dialect == "postgresql":
            await session.execute(text("SELECT pg_advisory_xact_lock(:lock_key)"), {"lock_key": self._oauth_scope_lock_key(owner_user_id, provider)})

    @staticmethod
    def _oauth_scope_lock_key(owner_user_id: str, provider: str) -> int:
        digest = hashlib.sha256(f"{owner_user_id}\x00{provider}".encode()).digest()
        # 63-bit non-negative key for pg_advisory_xact_lock(bigint).
        return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF

    async def delete_expired_oauth_states(self, *, now: datetime | None = None) -> int:
        current_time = now or datetime.now(UTC)
        async with self.session_factory() as session:
            result = await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
            await session.commit()
            return int(result.rowcount or 0)

    async def count_oauth_states(
        self,
        *,
        owner_user_id: str,
        provider: str,
        active_only: bool = False,
        now: datetime | None = None,
    ) -> int:
        current_time = now or datetime.now(UTC)
        conditions = [
            ChannelOAuthStateRow.owner_user_id == owner_user_id,
            ChannelOAuthStateRow.provider == provider,
        ]
        if active_only:
            conditions.extend(
                [
                    ChannelOAuthStateRow.consumed_at.is_(None),
                    ChannelOAuthStateRow.expires_at >= current_time,
                ]
            )
        async with self.session_factory() as session:
            result = await session.execute(select(func.count()).select_from(ChannelOAuthStateRow).where(*conditions))
            return int(result.scalar_one())

    async def consume_oauth_state(
        self,
        *,
        provider: str,
        state: str,
        now: datetime | None = None,
    ) -> dict[str, Any] | None:
        current_time = now or datetime.now(UTC)
        state_hash = self.hash_state(state)
        async with self.session_factory() as session:
            await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
            row = await session.get(ChannelOAuthStateRow, state_hash)
            if row is None or row.provider != provider or row.consumed_at is not None:
                await session.commit()
                return None
            expires_at = self._coerce_datetime(row.expires_at)
            if expires_at is not None and expires_at < current_time:
                await session.commit()
                return None
            # Conditional UPDATE so two concurrent workers cannot both consume
            # the same binding code: only the writer that flips consumed_at
            # from NULL wins.
            result = await session.execute(
                update(ChannelOAuthStateRow)
                .where(
                    ChannelOAuthStateRow.state_hash == state_hash,
                    ChannelOAuthStateRow.consumed_at.is_(None),
                )
                .values(consumed_at=current_time)
            )
            await session.commit()
            if result.rowcount != 1:
                return None
            return {
                "owner_user_id": row.owner_user_id,
                "provider": row.provider,
                "requested_scopes": row.requested_scopes_json or [],
                "metadata": row.metadata_json or {},
                "redirect_after": row.redirect_after,
            }

    async def find_connection_by_external_identity(
        self,
        *,
        provider: str,
        external_account_id: str,
        workspace_id: str | None = None,
    ) -> dict[str, Any] | None:
        async with self.session_factory() as session:
            result = await session.execute(
                select(ChannelConnectionRow)
                .where(
                    ChannelConnectionRow.provider == provider,
                    ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
                    ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
                    ChannelConnectionRow.status == "connected",
                )
                .order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
                .limit(1)
            )
            row = result.scalar_one_or_none()
            return self._connection_to_dict(row) if row is not None else None

    async def set_thread_id(
        self,
        *,
        connection_id: str,
        owner_user_id: str,
        provider: str,
        external_conversation_id: str,
        thread_id: str,
        external_topic_id: str | None = None,
    ) -> None:
        topic_id = external_topic_id or ""
        async with self.session_factory() as session:
            stmt = select(ChannelConversationRow).where(
                ChannelConversationRow.connection_id == connection_id,
                ChannelConversationRow.external_conversation_id == external_conversation_id,
                ChannelConversationRow.external_topic_id == topic_id,
            )
            row = (await session.execute(stmt)).scalar_one_or_none()
            if row is None:
                row = ChannelConversationRow(
                    id=self._new_id(),
                    connection_id=connection_id,
                    owner_user_id=owner_user_id,
                    provider=provider,
                    external_conversation_id=external_conversation_id,
                    external_topic_id=topic_id,
                    thread_id=thread_id,
                )
                session.add(row)
            else:
                row.thread_id = thread_id
                row.owner_user_id = owner_user_id
                row.provider = provider
            await session.commit()

    async def get_thread_id(
        self,
        connection_id: str,
        external_conversation_id: str,
        external_topic_id: str | None = None,
    ) -> str | None:
        async with self.session_factory() as session:
            stmt = select(ChannelConversationRow.thread_id).where(
                ChannelConversationRow.connection_id == connection_id,
                ChannelConversationRow.external_conversation_id == external_conversation_id,
                ChannelConversationRow.external_topic_id == (external_topic_id or ""),
            )
            return (await session.execute(stmt)).scalar_one_or_none()