deer-flow/backend/packages/harness/deerflow/persistence/channel_connections/sql.py
Nan Gao 8c0830aea1 fix(channels): add operational guardrails (#3584)
* fix(channels): add operational guardrails

* make format

* fix(channels): converge with #3582 to avoid merge-order conflicts

Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which
already restructures that handler and will redact the same log there. This
PR no longer touches dingtalk.py, so the two PRs can merge to main in any
order without a conflict.

For WeChat, drop the contested thread_ts priority reorder (review #3) and
keep only what inbound dedupe needs: a server-stable message_id in the
inbound metadata (message_id/msg_id, no client_id per review #6). This is a
single added line inside the metadata dict, a region #3582 never touches, so
it auto-merges regardless of order.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): address three correctness review findings

1. Connect-code cap was racy (willem #1): _create_state ran delete-expired,
   count, and insert as three separate transactions, so concurrent connect
   POSTs from one owner could each see count < cap and all insert past it. Add
   ChannelConnectionRepository.create_oauth_state_within_cap which does
   delete+count+insert in a single transaction serialized per (owner,
   provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock
   the leading DELETE takes — and have the router use it.

2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces
   delivering without team/guild/aibotid would collapse to the same key and
   dedupe each other's messages. _inbound_dedupe_key now fails closed
   (returns None) when no workspace identifier is present.

3. Dedupe key was recorded on receipt and never released on failure
   (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place
   for the full TTL, so a provider redelivery of the same message_id — exactly
   the retry dedupe should absorb — was silently dropped. _handle_message now
   releases the key in the unexpected-exception branch so redelivery can
   recover, while keeping record-on-receipt so retries during handling are
   still deduped.
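
The single-transaction idea in finding 1 can be illustrated outside SQLAlchemy. This is a minimal sketch using the standard-library `sqlite3` module, not the repository's implementation: the table layout, the name `create_state_within_cap`, and the integer timestamps are assumptions for illustration, and it takes the write lock explicitly with `BEGIN IMMEDIATE` rather than relying on the leading DELETE.

```python
import sqlite3


def create_state_within_cap(conn, owner, provider, state_hash, *, expires_at, now, max_pending):
    """Insert a pending state only if the (owner, provider) cap allows it."""
    # BEGIN IMMEDIATE acquires SQLite's write lock before anything is read,
    # so no other writer can insert between our count and our commit.
    conn.execute("BEGIN IMMEDIATE")
    try:
        # Prune only this owner/provider's expired rows.
        conn.execute(
            "DELETE FROM oauth_states WHERE owner = ? AND provider = ? AND expires_at < ?",
            (owner, provider, now),
        )
        (pending,) = conn.execute(
            "SELECT COUNT(*) FROM oauth_states "
            "WHERE owner = ? AND provider = ? AND consumed_at IS NULL AND expires_at >= ?",
            (owner, provider, now),
        ).fetchone()
        if pending >= max_pending:
            conn.execute("ROLLBACK")
            return False
        conn.execute(
            "INSERT INTO oauth_states (state_hash, owner, provider, expires_at) VALUES (?, ?, ?, ?)",
            (state_hash, owner, provider, expires_at),
        )
        conn.execute("COMMIT")
        return True
    except Exception:
        conn.execute("ROLLBACK")
        raise


# isolation_level=None puts the connection in autocommit mode so the
# explicit BEGIN/COMMIT statements above control the transaction.
conn = sqlite3.connect(":memory:", isolation_level=None)
conn.execute(
    "CREATE TABLE oauth_states ("
    "state_hash TEXT PRIMARY KEY, owner TEXT, provider TEXT, "
    "expires_at INTEGER, consumed_at INTEGER)"
)
results = [
    create_state_within_cap(conn, "u1", "slack", f"h{i}", expires_at=200, now=100, max_pending=2)
    for i in range(3)
]
# Once the earlier codes have expired they are pruned and no longer count.
after_expiry = create_state_within_cap(conn, "u1", "slack", "h9", expires_at=400, now=300, max_pending=2)
remaining = conn.execute("SELECT COUNT(*) FROM oauth_states").fetchone()[0]
```

With a cap of 2, the third issuance at the same instant is refused; a later request succeeds because the expired codes are deleted inside the same transaction that counts and inserts.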

Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe
fail-closed; dedupe key release-on-failure redelivery recovery.
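
Findings 2 and 3 can be sketched together. The code below is a hedged illustration, not the project's `_inbound_dedupe_key` or `_handle_message`: the metadata field names (`team_id`, `guild_id`, `aibotid`, `event_id`), the `InboundDeduper` class, and the choice to process a message undeduped when no key can be built are all assumptions.

```python
import time


class InboundDeduper:
    """Hypothetical in-memory store of recently seen dedupe keys."""

    def __init__(self, ttl_seconds=300.0):
        self._ttl = ttl_seconds
        self._seen = {}

    def record(self, key, now=None):
        """Record the key; return False if it was already seen within the TTL."""
        now = time.monotonic() if now is None else now
        seen_at = self._seen.get(key)
        if seen_at is not None and now - seen_at < self._ttl:
            return False
        self._seen[key] = now
        return True

    def release(self, key):
        """Forget the key so a provider redelivery is processed again."""
        self._seen.pop(key, None)


def inbound_dedupe_key(provider, metadata):
    """Build a dedupe key, or return None when it would be ambiguous."""
    workspace = metadata.get("team_id") or metadata.get("guild_id") or metadata.get("aibotid")
    message_id = metadata.get("event_id") or metadata.get("message_id") or metadata.get("msg_id")
    if not workspace or not message_id:
        # Fail closed: never fall back to an empty workspace, which would
        # let two unrelated workspaces share one key.
        return None
    return f"{provider}:{workspace}:{message_id}"


def handle_message(deduper, provider, metadata, process):
    key = inbound_dedupe_key(provider, metadata)
    # Record on receipt so retries that arrive during handling are deduped.
    if key is not None and not deduper.record(key):
        return "duplicate"
    try:
        process(metadata)
    except Exception:
        # Release on failure so the provider's redelivery can recover.
        if key is not None:
            deduper.release(key)
        return "failed"
    return "processed"


def flaky(_metadata):
    raise RuntimeError("transient gateway error")


deduper = InboundDeduper()
msg = {"team_id": "T1", "message_id": "m-1"}
outcomes = [
    handle_message(deduper, "slack", msg, flaky),
    handle_message(deduper, "slack", msg, lambda m: None),
    handle_message(deduper, "slack", msg, lambda m: None),
]
no_workspace_key = inbound_dedupe_key("slack", {"message_id": "m-1"})
```

The failed first delivery leaves no key behind, so the redelivery is processed and only the third copy is dropped as a duplicate.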

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): address cleanup/efficiency and test review findings

Efficiency / cleanup:
- Dedupe key set drops client-generated ids (client_msg_id, client_id);
  keep only server-stable event_id/message_id/msg_id, which a provider's own
  redelivery preserves (ShenAC #6). Every provider already emits message_id.
- TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an
  OrderedDict and popitem(last=False) from the front instead of scanning all
  4096 entries on every inbound (willem #4).
- Log "received inbound" only after the dedupe check so a provider retrying N
  times no longer logs N accepts; document that manager dedupe covers the
  agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2).
- Slack drops the redundant `team_id or event.get("team")` fallback the caller
  already resolved (willem #6).
- create_oauth_state_within_cap prunes only this owner/provider's expired codes
  instead of a global DELETE on every connect POST; global cleanup still runs
  on consume_oauth_state (willem #7).
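
The O(k) pruning described above for `_recent_inbound_events` can be sketched as follows. This is an illustrative stand-in, not the manager's code: the `RecentEvents` class, its method names, and the explicit `now` argument are assumptions. It relies on keys being inserted in non-decreasing time order and never re-ordered, so the oldest entry is always at the front.

```python
from collections import OrderedDict


class RecentEvents:
    """Hypothetical TTL- and size-bounded set of recently seen event keys."""

    def __init__(self, ttl_seconds, max_entries):
        self._ttl = ttl_seconds
        self._max = max_entries
        self._events = OrderedDict()  # key -> time first seen, oldest first

    def _prune_expired(self, now):
        # Stop at the first live entry: everything behind it is newer, so
        # only the k expired entries are touched instead of the whole map.
        while self._events:
            _, seen_at = next(iter(self._events.items()))
            if now - seen_at < self._ttl:
                break
            self._events.popitem(last=False)

    def seen_before(self, key, now):
        """Return True for a duplicate; otherwise record the key."""
        self._prune_expired(now)
        if key in self._events:
            return True
        self._events[key] = now
        # Overflow: evict from the front until back under the size bound.
        while len(self._events) > self._max:
            self._events.popitem(last=False)
        return False

    def __len__(self):
        return len(self._events)


events = RecentEvents(ttl_seconds=10, max_entries=2)
trace = [
    events.seen_before("a", 0),    # new
    events.seen_before("a", 5),    # duplicate within the TTL
    events.seen_before("b", 6),    # new
    events.seen_before("c", 7),    # new; overflow evicts "a"
    events.seen_before("a", 8),    # "a" was evicted, so it is new again
    events.seen_before("c", 100),  # everything expired, so "c" is new again
]
```

Duplicates are deliberately not moved to the back on lookup; doing so would break the oldest-first ordering that the early `break` depends on.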

Tests:
- Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects
  per publish, and adds a negative control: a different message_id is still
  processed, catching over-dedupe regressions (willem #8, ShenAC #4).
- Slack HTTP-mode rejection test supplies app_token so the missing-token early
  return can't mask the guard, giving the state assertions teeth (ShenAC #3).
- count_oauth_states test pins that the active row survives, not just the count
  (ShenAC #5).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* make format

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-18 10:09:46 +08:00

518 lines
21 KiB
Python

"""SQL repository for user-owned IM channel connections."""

from __future__ import annotations

import base64
import hashlib
import json
import logging
import uuid
from datetime import UTC, datetime
from typing import Any

from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from deerflow.persistence.channel_connections.model import (
    ChannelConnectionRow,
    ChannelConversationRow,
    ChannelCredentialRow,
    ChannelOAuthStateRow,
)
from deerflow.utils.time import coerce_iso

logger = logging.getLogger(__name__)


class ChannelCredentialCipher:
    """Encrypts provider credentials before they are persisted."""

    def __init__(self, fernet: Fernet) -> None:
        self._fernet = fernet

    @classmethod
    def from_key(cls, key: str) -> ChannelCredentialCipher:
        digest = hashlib.sha256(key.encode("utf-8")).digest()
        return cls(Fernet(base64.urlsafe_b64encode(digest)))

    def encrypt_text(self, value: str | None) -> str | None:
        if value is None:
            return None
        return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")

    def decrypt_text(self, value: str | None) -> str | None:
        if value is None:
            return None
        token = value.removeprefix("fernet:v1:")
        return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")


class ChannelConnectionRepository:
    """Persistence facade for channel connections, credentials, and conversations."""

    def __init__(
        self,
        session_factory: async_sessionmaker[AsyncSession],
        *,
        cipher: ChannelCredentialCipher | None = None,
    ) -> None:
        self.session_factory = session_factory
        self._cipher = cipher

    async def close(self) -> None:
        from deerflow.persistence.engine import close_engine

        await close_engine()

    @staticmethod
    def _new_id() -> str:
        return uuid.uuid4().hex

    @staticmethod
    def _normalize_optional_identity(value: str | None) -> str:
        return value or ""

    @staticmethod
    def _coerce_datetime(value: datetime | None) -> datetime | None:
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=UTC)

    def _encrypt_optional_secret(self, value: str | None) -> str | None:
        if value is None:
            return None
        if self._cipher is None:
            raise RuntimeError("channel connection encryption key is required")
        return self._cipher.encrypt_text(value)

    @staticmethod
    def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
        data = row.to_dict()
        data["external_account_id"] = data["external_account_id"] or None
        data["workspace_id"] = data["workspace_id"] or None
        data["scopes"] = data.pop("scopes_json") or []
        data["capabilities"] = data.pop("capabilities_json") or {}
        data["metadata"] = data.pop("metadata_json") or {}
        for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
            value = data.get(key)
            if isinstance(value, datetime):
                data[key] = coerce_iso(value)
        return data

    async def upsert_connection(
        self,
        *,
        owner_user_id: str,
        provider: str,
        external_account_id: str | None = None,
        external_account_name: str | None = None,
        workspace_id: str | None = None,
        workspace_name: str | None = None,
        bot_user_id: str | None = None,
        scopes: list[str] | None = None,
        capabilities: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        status: str = "connected",
    ) -> dict[str, Any]:
        external_account_id_value = self._normalize_optional_identity(external_account_id)
        workspace_id_value = self._normalize_optional_identity(workspace_id)

        def _apply(row: ChannelConnectionRow) -> None:
            row.status = status
            row.external_account_name = external_account_name
            row.workspace_name = workspace_name
            row.bot_user_id = bot_user_id
            row.scopes_json = list(scopes or [])
            row.capabilities_json = dict(capabilities or {})
            row.metadata_json = dict(metadata or {})

        stmt = select(ChannelConnectionRow).where(
            ChannelConnectionRow.owner_user_id == owner_user_id,
            ChannelConnectionRow.provider == provider,
            ChannelConnectionRow.external_account_id == external_account_id_value,
            ChannelConnectionRow.workspace_id == workspace_id_value,
        )
        async with self.session_factory() as session:
            row = (await session.execute(stmt)).scalar_one_or_none()
            if row is None:
                row = ChannelConnectionRow(
                    id=self._new_id(),
                    owner_user_id=owner_user_id,
                    provider=provider,
                    external_account_id=external_account_id_value,
                    workspace_id=workspace_id_value,
                )
                session.add(row)
            _apply(row)
            try:
                await session.commit()
            except IntegrityError:
                # A concurrent writer inserted the same identity first; retry as
                # an update of that row.
                await session.rollback()
                row = (await session.execute(stmt)).scalar_one()
                _apply(row)
                await session.commit()
            await session.refresh(row)
            return self._connection_to_dict(row)

    async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
        async with self.session_factory() as session:
            result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
            return [self._connection_to_dict(row) for row in result.scalars()]

    async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
        async with self.session_factory() as session:
            row = await session.get(ChannelConnectionRow, connection_id)
            if row is None or row.owner_user_id != owner_user_id:
                return False
            row.status = "revoked"
            credential = await session.get(ChannelCredentialRow, connection_id)
            if credential is not None:
                await session.delete(credential)
            await session.commit()
            return True

    async def disconnect_provider_connections(self, *, provider: str) -> int:
        """Revoke all active user connections for an instance-wide provider removal."""
        async with self.session_factory() as session:
            result = await session.execute(
                select(ChannelConnectionRow.id).where(
                    ChannelConnectionRow.provider == provider,
                    ChannelConnectionRow.status != "revoked",
                )
            )
            connection_ids = [row_id for row_id in result.scalars()]
            if not connection_ids:
                return 0
            await session.execute(update(ChannelConnectionRow).where(ChannelConnectionRow.id.in_(connection_ids)).values(status="revoked"))
            await session.execute(delete(ChannelCredentialRow).where(ChannelCredentialRow.connection_id.in_(connection_ids)))
            await session.commit()
            return len(connection_ids)

    async def store_credentials(
        self,
        connection_id: str,
        *,
        access_token: str | None,
        refresh_token: str | None = None,
        token_type: str | None = None,
        expires_at: datetime | None = None,
        refresh_expires_at: datetime | None = None,
        extra: dict[str, Any] | None = None,
    ) -> None:
        if self._cipher is None:
            raise RuntimeError("channel connection encryption key is required")
        async with self.session_factory() as session:
            row = await session.get(ChannelCredentialRow, connection_id)
            if row is None:
                row = ChannelCredentialRow(connection_id=connection_id)
                session.add(row)
            row.encrypted_access_token = self._cipher.encrypt_text(access_token)
            row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
            row.token_type = token_type
            row.expires_at = expires_at
            row.refresh_expires_at = refresh_expires_at
            row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
            row.version = (row.version or 0) + 1
            await session.commit()

    async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
        if self._cipher is None:
            return None
        async with self.session_factory() as session:
            row = await session.get(ChannelCredentialRow, connection_id)
            if row is None:
                return None
            try:
                extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
                return {
                    "connection_id": row.connection_id,
                    "access_token": self._cipher.decrypt_text(row.encrypted_access_token),
                    "refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
                    "token_type": row.token_type,
                    "expires_at": self._coerce_datetime(row.expires_at),
                    "refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
                    "extra": json.loads(extra_raw) if extra_raw else {},
                }
            except (InvalidToken, UnicodeError, json.JSONDecodeError):
                logger.warning(
                    "Unable to decrypt channel connection credentials; treating credentials as unavailable",
                    exc_info=True,
                )
                return None

    @staticmethod
    def hash_state(state: str) -> str:
        return hashlib.sha256(state.encode("utf-8")).hexdigest()

    async def create_oauth_state(
        self,
        *,
        owner_user_id: str,
        provider: str,
        state: str,
        expires_at: datetime,
        code_verifier: str | None = None,
        nonce_hash: str | None = None,
        redirect_after: str | None = None,
        requested_scopes: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        row = ChannelOAuthStateRow(
            state_hash=self.hash_state(state),
            owner_user_id=owner_user_id,
            provider=provider,
            code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
            nonce_hash=nonce_hash,
            redirect_after=redirect_after,
            requested_scopes_json=list(requested_scopes or []),
            metadata_json=dict(metadata or {}),
            expires_at=expires_at,
        )
        async with self.session_factory() as session:
            session.add(row)
            await session.commit()

    async def create_oauth_state_within_cap(
        self,
        *,
        owner_user_id: str,
        provider: str,
        state: str,
        expires_at: datetime,
        max_pending: int,
        now: datetime | None = None,
        code_verifier: str | None = None,
        nonce_hash: str | None = None,
        redirect_after: str | None = None,
        requested_scopes: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> bool:
        """Atomically enforce the per-(owner, provider) pending cap, then insert.

        delete-expired + count + insert run in a single transaction serialized
        per (owner, provider), so concurrent connect requests cannot each
        observe ``count < max_pending`` and all insert (which would leak past
        the cap). PostgreSQL takes a transaction-scoped advisory lock; SQLite
        serializes writers through the write lock the leading DELETE acquires.

        Returns ``True`` when the row was inserted, ``False`` when the cap is
        already reached.
        """
        current_time = now or datetime.now(UTC)
        async with self.session_factory() as session:
            await self._serialize_oauth_owner_scope(session, owner_user_id, provider)
            # Prune only this owner/provider's expired codes (the ones that affect
            # this cap), not every user's; this avoids a global DELETE on each
            # connect POST. Issuing this write first also takes the SQLite database
            # write lock so the count below cannot race a concurrent inserter
            # between count and commit. Stale codes for other owners are pruned
            # globally by consume_oauth_state / delete_expired_oauth_states.
            await session.execute(
                delete(ChannelOAuthStateRow).where(
                    ChannelOAuthStateRow.owner_user_id == owner_user_id,
                    ChannelOAuthStateRow.provider == provider,
                    ChannelOAuthStateRow.expires_at < current_time,
                )
            )
            pending = await session.execute(
                select(func.count())
                .select_from(ChannelOAuthStateRow)
                .where(
                    ChannelOAuthStateRow.owner_user_id == owner_user_id,
                    ChannelOAuthStateRow.provider == provider,
                    ChannelOAuthStateRow.consumed_at.is_(None),
                    ChannelOAuthStateRow.expires_at >= current_time,
                )
            )
            if int(pending.scalar_one()) >= max_pending:
                await session.rollback()
                return False
            session.add(
                ChannelOAuthStateRow(
                    state_hash=self.hash_state(state),
                    owner_user_id=owner_user_id,
                    provider=provider,
                    code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
                    nonce_hash=nonce_hash,
                    redirect_after=redirect_after,
                    requested_scopes_json=list(requested_scopes or []),
                    metadata_json=dict(metadata or {}),
                    expires_at=expires_at,
                )
            )
            await session.commit()
            return True

    async def _serialize_oauth_owner_scope(self, session: AsyncSession, owner_user_id: str, provider: str) -> None:
        """Serialize concurrent pending-cap transactions for one (owner, provider).

        On PostgreSQL this takes a transaction-scoped advisory lock so concurrent
        issuers run their count+insert one at a time. On SQLite the leading
        DELETE in the caller's transaction already acquires the database write
        lock, which serializes writers, so no extra lock is required.
        """
        try:
            dialect = session.bind.dialect.name if session.bind is not None else ""
        except Exception:
            dialect = ""
        if dialect == "postgresql":
            await session.execute(text("SELECT pg_advisory_xact_lock(:lock_key)"), {"lock_key": self._oauth_scope_lock_key(owner_user_id, provider)})

    @staticmethod
    def _oauth_scope_lock_key(owner_user_id: str, provider: str) -> int:
        digest = hashlib.sha256(f"{owner_user_id}\x00{provider}".encode()).digest()
        # 63-bit non-negative key for pg_advisory_xact_lock(bigint).
        return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF

    async def delete_expired_oauth_states(self, *, now: datetime | None = None) -> int:
        current_time = now or datetime.now(UTC)
        async with self.session_factory() as session:
            result = await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
            await session.commit()
            return int(result.rowcount or 0)

    async def count_oauth_states(
        self,
        *,
        owner_user_id: str,
        provider: str,
        active_only: bool = False,
        now: datetime | None = None,
    ) -> int:
        current_time = now or datetime.now(UTC)
        conditions = [
            ChannelOAuthStateRow.owner_user_id == owner_user_id,
            ChannelOAuthStateRow.provider == provider,
        ]
        if active_only:
            conditions.extend(
                [
                    ChannelOAuthStateRow.consumed_at.is_(None),
                    ChannelOAuthStateRow.expires_at >= current_time,
                ]
            )
        async with self.session_factory() as session:
            result = await session.execute(select(func.count()).select_from(ChannelOAuthStateRow).where(*conditions))
            return int(result.scalar_one())

    async def consume_oauth_state(
        self,
        *,
        provider: str,
        state: str,
        now: datetime | None = None,
    ) -> dict[str, Any] | None:
        current_time = now or datetime.now(UTC)
        state_hash = self.hash_state(state)
        async with self.session_factory() as session:
            await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
            row = await session.get(ChannelOAuthStateRow, state_hash)
            if row is None or row.provider != provider or row.consumed_at is not None:
                await session.commit()
                return None
            expires_at = self._coerce_datetime(row.expires_at)
            if expires_at is not None and expires_at < current_time:
                await session.commit()
                return None
            # Conditional UPDATE so two concurrent workers cannot both consume
            # the same binding code: only the writer that flips consumed_at
            # from NULL wins.
            result = await session.execute(
                update(ChannelOAuthStateRow)
                .where(
                    ChannelOAuthStateRow.state_hash == state_hash,
                    ChannelOAuthStateRow.consumed_at.is_(None),
                )
                .values(consumed_at=current_time)
            )
            await session.commit()
            if result.rowcount != 1:
                return None
            return {
                "owner_user_id": row.owner_user_id,
                "provider": row.provider,
                "requested_scopes": row.requested_scopes_json or [],
                "metadata": row.metadata_json or {},
                "redirect_after": row.redirect_after,
            }

    async def find_connection_by_external_identity(
        self,
        *,
        provider: str,
        external_account_id: str,
        workspace_id: str | None = None,
    ) -> dict[str, Any] | None:
        async with self.session_factory() as session:
            result = await session.execute(
                select(ChannelConnectionRow)
                .where(
                    ChannelConnectionRow.provider == provider,
                    ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
                    ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
                    ChannelConnectionRow.status == "connected",
                )
                .order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
                .limit(1)
            )
            row = result.scalar_one_or_none()
            return self._connection_to_dict(row) if row is not None else None

    async def set_thread_id(
        self,
        *,
        connection_id: str,
        owner_user_id: str,
        provider: str,
        external_conversation_id: str,
        thread_id: str,
        external_topic_id: str | None = None,
    ) -> None:
        topic_id = external_topic_id or ""
        async with self.session_factory() as session:
            stmt = select(ChannelConversationRow).where(
                ChannelConversationRow.connection_id == connection_id,
                ChannelConversationRow.external_conversation_id == external_conversation_id,
                ChannelConversationRow.external_topic_id == topic_id,
            )
            row = (await session.execute(stmt)).scalar_one_or_none()
            if row is None:
                row = ChannelConversationRow(
                    id=self._new_id(),
                    connection_id=connection_id,
                    owner_user_id=owner_user_id,
                    provider=provider,
                    external_conversation_id=external_conversation_id,
                    external_topic_id=topic_id,
                    thread_id=thread_id,
                )
                session.add(row)
            else:
                row.thread_id = thread_id
                row.owner_user_id = owner_user_id
                row.provider = provider
            await session.commit()

    async def get_thread_id(
        self,
        connection_id: str,
        external_conversation_id: str,
        external_topic_id: str | None = None,
    ) -> str | None:
        async with self.session_factory() as session:
            stmt = select(ChannelConversationRow.thread_id).where(
                ChannelConversationRow.connection_id == connection_id,
                ChannelConversationRow.external_conversation_id == external_conversation_id,
                ChannelConversationRow.external_topic_id == (external_topic_id or ""),
            )
            return (await session.execute(stmt)).scalar_one_or_none()
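
The conditional UPDATE used by `consume_oauth_state` can be tried in isolation. The sketch below uses the standard-library `sqlite3` module instead of SQLAlchemy; the reduced two-column table, the helper name `consume_once`, and the integer timestamps are assumptions for illustration only.

```python
import sqlite3


def consume_once(conn, state_hash, now):
    """Mark a state consumed; return True only for the caller that flipped it."""
    cur = conn.execute(
        # The `consumed_at IS NULL` guard makes the check and the write one
        # statement, so a second consumer matches zero rows.
        "UPDATE oauth_states SET consumed_at = ? WHERE state_hash = ? AND consumed_at IS NULL",
        (now, state_hash),
    )
    conn.commit()
    return cur.rowcount == 1


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE oauth_states (state_hash TEXT PRIMARY KEY, consumed_at INTEGER)")
conn.execute("INSERT INTO oauth_states (state_hash) VALUES ('abc')")
conn.commit()

first = consume_once(conn, "abc", 100)
second = consume_once(conn, "abc", 101)
unknown = consume_once(conn, "missing", 102)
consumed_at = conn.execute("SELECT consumed_at FROM oauth_states WHERE state_hash = 'abc'").fetchone()[0]
```

Checking the affected row count, rather than reading the row and then writing it, is what keeps two workers from both treating the same code as freshly consumed.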