mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 13:46:02 +00:00
8c0830aea1
* fix(channels): add operational guardrails * make format * fix(channels): converge with #3582 to avoid merge-order conflicts Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which already restructures that handler and will redact the same log there. This PR no longer touches dingtalk.py, so the two PRs can merge to main in any order without a conflict. For WeChat, drop the contested thread_ts priority reorder (review #3) and keep only what inbound dedupe needs: a server-stable message_id in the inbound metadata (message_id/msg_id, no client_id per review #6). This is a single added line inside the metadata dict, a region #3582 never touches, so it auto-merges regardless of order. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * fix(channels): address three correctness review findings 1. Connect-code cap was racy (willem #1): _create_state ran delete-expired, count, and insert as three separate transactions, so concurrent connect POSTs from one owner could each see count < cap and all insert past it. Add ChannelConnectionRepository.create_oauth_state_within_cap which does delete+count+insert in a single transaction serialized per (owner, provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock the leading DELETE takes — and have the router use it. 2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces delivering without team/guild/aibotid would collapse to the same key and dedupe each other's messages. _inbound_dedupe_key now fails closed (returns None) when no workspace identifier is present. 3. Dedupe key was recorded on receipt and never released on failure (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place for the full TTL, so a provider redelivery of the same message_id — exactly the retry dedupe should absorb — was silently dropped. 
_handle_message now releases the key in the unexpected-exception branch so redelivery can recover, while keeping record-on-receipt so retries during handling are still deduped. Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe fail-closed; dedupe key release-on-failure redelivery recovery. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * fix(channels): address cleanup/efficiency and test review findings Efficiency / cleanup: - Dedupe key set drops client-generated ids (client_msg_id, client_id); keep only server-stable event_id/message_id/msg_id, which a provider's own redelivery preserves (ShenAC #6). Every provider already emits message_id. - TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an OrderedDict and popitem(last=False) from the front instead of scanning all 4096 entries on every inbound (willem #4). - Log "received inbound" only after the dedupe check so a provider retrying N times no longer logs N accepts; document that manager dedupe covers the agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2). - Slack drops the redundant `team_id or event.get("team")` fallback the caller already resolved (willem #6). - create_oauth_state_within_cap prunes only this owner/provider's expired codes instead of a global DELETE on every connect POST; global cleanup still runs on consume_oauth_state (willem #7). Tests: - Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects per publish, and adds a negative control: a different message_id is still processed, catching over-dedupe regressions (willem #8, ShenAC #4). - Slack HTTP-mode rejection test supplies app_token so the missing-token early return can't mask the guard, giving the state assertions teeth (ShenAC #3). - count_oauth_states test pins that the active row survives, not just the count (ShenAC #5). 
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * make format --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
518 lines
21 KiB
Python
518 lines
21 KiB
Python
"""SQL repository for user-owned IM channel connections."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
from sqlalchemy import delete, func, select, text, update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
|
|
from deerflow.persistence.channel_connections.model import (
|
|
ChannelConnectionRow,
|
|
ChannelConversationRow,
|
|
ChannelCredentialRow,
|
|
ChannelOAuthStateRow,
|
|
)
|
|
from deerflow.utils.time import coerce_iso
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChannelCredentialCipher:
    """Text cipher applied to provider credentials before they are stored.

    Ciphertext produced by :meth:`encrypt_text` carries a ``fernet:v1:``
    prefix; :meth:`decrypt_text` strips that prefix when present.
    """

    def __init__(self, fernet: Fernet) -> None:
        """Wrap an already-constructed ``Fernet`` instance."""
        self._fernet = fernet

    @classmethod
    def from_key(cls, key: str) -> ChannelCredentialCipher:
        """Build a cipher from an arbitrary string key.

        The key's UTF-8 bytes are hashed with SHA-256 and the 32-byte digest
        is URL-safe base64 encoded to form the Fernet key.
        """
        key_material = hashlib.sha256(key.encode("utf-8")).digest()
        fernet_key = base64.urlsafe_b64encode(key_material)
        return cls(Fernet(fernet_key))

    def encrypt_text(self, value: str | None) -> str | None:
        """Encrypt *value*, returning prefixed ciphertext; ``None`` passes through."""
        if value is not None:
            ciphertext = self._fernet.encrypt(value.encode("utf-8"))
            return "fernet:v1:" + ciphertext.decode("ascii")
        return None

    def decrypt_text(self, value: str | None) -> str | None:
        """Decrypt text produced by :meth:`encrypt_text`; ``None`` passes through.

        A value without the ``fernet:v1:`` prefix is handed to Fernet as-is.
        Errors raised by the underlying ``decrypt`` call are not caught here.
        """
        if value is None:
            return None
        raw_token = value.removeprefix("fernet:v1:").encode("ascii")
        plaintext = self._fernet.decrypt(raw_token)
        return plaintext.decode("utf-8")
|
|
|
|
|
|
class ChannelConnectionRepository:
|
|
"""Persistence facade for channel connections, credentials, and conversations."""
|
|
|
|
def __init__(
|
|
self,
|
|
session_factory: async_sessionmaker[AsyncSession],
|
|
*,
|
|
cipher: ChannelCredentialCipher | None = None,
|
|
) -> None:
|
|
self.session_factory = session_factory
|
|
self._cipher = cipher
|
|
|
|
async def close(self) -> None:
|
|
from deerflow.persistence.engine import close_engine
|
|
|
|
await close_engine()
|
|
|
|
@staticmethod
|
|
def _new_id() -> str:
|
|
return uuid.uuid4().hex
|
|
|
|
@staticmethod
|
|
def _normalize_optional_identity(value: str | None) -> str:
|
|
return value or ""
|
|
|
|
@staticmethod
|
|
def _coerce_datetime(value: datetime | None) -> datetime | None:
|
|
if value is None or value.tzinfo is not None:
|
|
return value
|
|
return value.replace(tzinfo=UTC)
|
|
|
|
def _encrypt_optional_secret(self, value: str | None) -> str | None:
|
|
if value is None:
|
|
return None
|
|
if self._cipher is None:
|
|
raise RuntimeError("channel connection encryption key is required")
|
|
return self._cipher.encrypt_text(value)
|
|
|
|
@staticmethod
|
|
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
|
|
data = row.to_dict()
|
|
data["external_account_id"] = data["external_account_id"] or None
|
|
data["workspace_id"] = data["workspace_id"] or None
|
|
data["scopes"] = data.pop("scopes_json") or []
|
|
data["capabilities"] = data.pop("capabilities_json") or {}
|
|
data["metadata"] = data.pop("metadata_json") or {}
|
|
for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
|
|
value = data.get(key)
|
|
if isinstance(value, datetime):
|
|
data[key] = coerce_iso(value)
|
|
return data
|
|
|
|
async def upsert_connection(
|
|
self,
|
|
*,
|
|
owner_user_id: str,
|
|
provider: str,
|
|
external_account_id: str | None = None,
|
|
external_account_name: str | None = None,
|
|
workspace_id: str | None = None,
|
|
workspace_name: str | None = None,
|
|
bot_user_id: str | None = None,
|
|
scopes: list[str] | None = None,
|
|
capabilities: dict[str, Any] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
status: str = "connected",
|
|
) -> dict[str, Any]:
|
|
external_account_id_value = self._normalize_optional_identity(external_account_id)
|
|
workspace_id_value = self._normalize_optional_identity(workspace_id)
|
|
|
|
def _apply(row: ChannelConnectionRow) -> None:
|
|
row.status = status
|
|
row.external_account_name = external_account_name
|
|
row.workspace_name = workspace_name
|
|
row.bot_user_id = bot_user_id
|
|
row.scopes_json = list(scopes or [])
|
|
row.capabilities_json = dict(capabilities or {})
|
|
row.metadata_json = dict(metadata or {})
|
|
|
|
stmt = select(ChannelConnectionRow).where(
|
|
ChannelConnectionRow.owner_user_id == owner_user_id,
|
|
ChannelConnectionRow.provider == provider,
|
|
ChannelConnectionRow.external_account_id == external_account_id_value,
|
|
ChannelConnectionRow.workspace_id == workspace_id_value,
|
|
)
|
|
async with self.session_factory() as session:
|
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
|
if row is None:
|
|
row = ChannelConnectionRow(
|
|
id=self._new_id(),
|
|
owner_user_id=owner_user_id,
|
|
provider=provider,
|
|
external_account_id=external_account_id_value,
|
|
workspace_id=workspace_id_value,
|
|
)
|
|
session.add(row)
|
|
|
|
_apply(row)
|
|
try:
|
|
await session.commit()
|
|
except IntegrityError:
|
|
# A concurrent writer inserted the same identity first; retry as
|
|
# an update of that row.
|
|
await session.rollback()
|
|
row = (await session.execute(stmt)).scalar_one()
|
|
_apply(row)
|
|
await session.commit()
|
|
await session.refresh(row)
|
|
return self._connection_to_dict(row)
|
|
|
|
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
|
|
async with self.session_factory() as session:
|
|
result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
|
|
return [self._connection_to_dict(row) for row in result.scalars()]
|
|
|
|
async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
|
|
async with self.session_factory() as session:
|
|
row = await session.get(ChannelConnectionRow, connection_id)
|
|
if row is None or row.owner_user_id != owner_user_id:
|
|
return False
|
|
|
|
row.status = "revoked"
|
|
credential = await session.get(ChannelCredentialRow, connection_id)
|
|
if credential is not None:
|
|
await session.delete(credential)
|
|
await session.commit()
|
|
return True
|
|
|
|
async def disconnect_provider_connections(self, *, provider: str) -> int:
|
|
"""Revoke all active user connections for an instance-wide provider removal."""
|
|
async with self.session_factory() as session:
|
|
result = await session.execute(
|
|
select(ChannelConnectionRow.id).where(
|
|
ChannelConnectionRow.provider == provider,
|
|
ChannelConnectionRow.status != "revoked",
|
|
)
|
|
)
|
|
connection_ids = [row_id for row_id in result.scalars()]
|
|
if not connection_ids:
|
|
return 0
|
|
|
|
await session.execute(update(ChannelConnectionRow).where(ChannelConnectionRow.id.in_(connection_ids)).values(status="revoked"))
|
|
await session.execute(delete(ChannelCredentialRow).where(ChannelCredentialRow.connection_id.in_(connection_ids)))
|
|
await session.commit()
|
|
return len(connection_ids)
|
|
|
|
async def store_credentials(
|
|
self,
|
|
connection_id: str,
|
|
*,
|
|
access_token: str | None,
|
|
refresh_token: str | None = None,
|
|
token_type: str | None = None,
|
|
expires_at: datetime | None = None,
|
|
refresh_expires_at: datetime | None = None,
|
|
extra: dict[str, Any] | None = None,
|
|
) -> None:
|
|
if self._cipher is None:
|
|
raise RuntimeError("channel connection encryption key is required")
|
|
async with self.session_factory() as session:
|
|
row = await session.get(ChannelCredentialRow, connection_id)
|
|
if row is None:
|
|
row = ChannelCredentialRow(connection_id=connection_id)
|
|
session.add(row)
|
|
row.encrypted_access_token = self._cipher.encrypt_text(access_token)
|
|
row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
|
|
row.token_type = token_type
|
|
row.expires_at = expires_at
|
|
row.refresh_expires_at = refresh_expires_at
|
|
row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
|
|
row.version = (row.version or 0) + 1
|
|
await session.commit()
|
|
|
|
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
|
|
if self._cipher is None:
|
|
return None
|
|
async with self.session_factory() as session:
|
|
row = await session.get(ChannelCredentialRow, connection_id)
|
|
if row is None:
|
|
return None
|
|
try:
|
|
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
|
|
return {
|
|
"connection_id": row.connection_id,
|
|
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
|
|
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
|
|
"token_type": row.token_type,
|
|
"expires_at": self._coerce_datetime(row.expires_at),
|
|
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
|
|
"extra": json.loads(extra_raw) if extra_raw else {},
|
|
}
|
|
except (InvalidToken, UnicodeError, json.JSONDecodeError):
|
|
logger.warning(
|
|
"Unable to decrypt channel connection credentials; treating credentials as unavailable",
|
|
exc_info=True,
|
|
)
|
|
return None
|
|
|
|
@staticmethod
|
|
def hash_state(state: str) -> str:
|
|
return hashlib.sha256(state.encode("utf-8")).hexdigest()
|
|
|
|
async def create_oauth_state(
|
|
self,
|
|
*,
|
|
owner_user_id: str,
|
|
provider: str,
|
|
state: str,
|
|
expires_at: datetime,
|
|
code_verifier: str | None = None,
|
|
nonce_hash: str | None = None,
|
|
redirect_after: str | None = None,
|
|
requested_scopes: list[str] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> None:
|
|
row = ChannelOAuthStateRow(
|
|
state_hash=self.hash_state(state),
|
|
owner_user_id=owner_user_id,
|
|
provider=provider,
|
|
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
|
|
nonce_hash=nonce_hash,
|
|
redirect_after=redirect_after,
|
|
requested_scopes_json=list(requested_scopes or []),
|
|
metadata_json=dict(metadata or {}),
|
|
expires_at=expires_at,
|
|
)
|
|
async with self.session_factory() as session:
|
|
session.add(row)
|
|
await session.commit()
|
|
|
|
async def create_oauth_state_within_cap(
|
|
self,
|
|
*,
|
|
owner_user_id: str,
|
|
provider: str,
|
|
state: str,
|
|
expires_at: datetime,
|
|
max_pending: int,
|
|
now: datetime | None = None,
|
|
code_verifier: str | None = None,
|
|
nonce_hash: str | None = None,
|
|
redirect_after: str | None = None,
|
|
requested_scopes: list[str] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> bool:
|
|
"""Atomically enforce the per-(owner, provider) pending cap, then insert.
|
|
|
|
delete-expired + count + insert run in a single transaction serialized
|
|
per (owner, provider), so concurrent connect requests cannot each
|
|
observe ``count < max_pending`` and all insert (which would leak past
|
|
the cap). PostgreSQL takes a transaction-scoped advisory lock; SQLite
|
|
serializes writers through the write lock the leading DELETE acquires.
|
|
|
|
Returns ``True`` when the row was inserted, ``False`` when the cap is
|
|
already reached.
|
|
"""
|
|
current_time = now or datetime.now(UTC)
|
|
async with self.session_factory() as session:
|
|
await self._serialize_oauth_owner_scope(session, owner_user_id, provider)
|
|
# Prune only this owner/provider's expired codes (the ones that affect
|
|
# this cap), not every user's — avoids a global DELETE on each connect
|
|
# POST. Issuing this write first also takes the SQLite database write
|
|
# lock so the count below cannot race a concurrent inserter between
|
|
# count and commit. Stale codes for other owners are pruned globally
|
|
# by consume_oauth_state / delete_expired_oauth_states.
|
|
await session.execute(
|
|
delete(ChannelOAuthStateRow).where(
|
|
ChannelOAuthStateRow.owner_user_id == owner_user_id,
|
|
ChannelOAuthStateRow.provider == provider,
|
|
ChannelOAuthStateRow.expires_at < current_time,
|
|
)
|
|
)
|
|
pending = await session.execute(
|
|
select(func.count())
|
|
.select_from(ChannelOAuthStateRow)
|
|
.where(
|
|
ChannelOAuthStateRow.owner_user_id == owner_user_id,
|
|
ChannelOAuthStateRow.provider == provider,
|
|
ChannelOAuthStateRow.consumed_at.is_(None),
|
|
ChannelOAuthStateRow.expires_at >= current_time,
|
|
)
|
|
)
|
|
if int(pending.scalar_one()) >= max_pending:
|
|
await session.rollback()
|
|
return False
|
|
session.add(
|
|
ChannelOAuthStateRow(
|
|
state_hash=self.hash_state(state),
|
|
owner_user_id=owner_user_id,
|
|
provider=provider,
|
|
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
|
|
nonce_hash=nonce_hash,
|
|
redirect_after=redirect_after,
|
|
requested_scopes_json=list(requested_scopes or []),
|
|
metadata_json=dict(metadata or {}),
|
|
expires_at=expires_at,
|
|
)
|
|
)
|
|
await session.commit()
|
|
return True
|
|
|
|
async def _serialize_oauth_owner_scope(self, session: AsyncSession, owner_user_id: str, provider: str) -> None:
|
|
"""Serialize concurrent pending-cap transactions for one (owner, provider).
|
|
|
|
On PostgreSQL this takes a transaction-scoped advisory lock so concurrent
|
|
issuers run their count+insert one at a time. On SQLite the leading
|
|
DELETE in the caller's transaction already acquires the database write
|
|
lock, which serializes writers, so no extra lock is required.
|
|
"""
|
|
try:
|
|
dialect = session.bind.dialect.name if session.bind is not None else ""
|
|
except Exception:
|
|
dialect = ""
|
|
if dialect == "postgresql":
|
|
await session.execute(text("SELECT pg_advisory_xact_lock(:lock_key)"), {"lock_key": self._oauth_scope_lock_key(owner_user_id, provider)})
|
|
|
|
@staticmethod
|
|
def _oauth_scope_lock_key(owner_user_id: str, provider: str) -> int:
|
|
digest = hashlib.sha256(f"{owner_user_id}\x00{provider}".encode()).digest()
|
|
# 63-bit non-negative key for pg_advisory_xact_lock(bigint).
|
|
return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF
|
|
|
|
async def delete_expired_oauth_states(self, *, now: datetime | None = None) -> int:
|
|
current_time = now or datetime.now(UTC)
|
|
async with self.session_factory() as session:
|
|
result = await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
|
|
await session.commit()
|
|
return int(result.rowcount or 0)
|
|
|
|
async def count_oauth_states(
|
|
self,
|
|
*,
|
|
owner_user_id: str,
|
|
provider: str,
|
|
active_only: bool = False,
|
|
now: datetime | None = None,
|
|
) -> int:
|
|
current_time = now or datetime.now(UTC)
|
|
conditions = [
|
|
ChannelOAuthStateRow.owner_user_id == owner_user_id,
|
|
ChannelOAuthStateRow.provider == provider,
|
|
]
|
|
if active_only:
|
|
conditions.extend(
|
|
[
|
|
ChannelOAuthStateRow.consumed_at.is_(None),
|
|
ChannelOAuthStateRow.expires_at >= current_time,
|
|
]
|
|
)
|
|
|
|
async with self.session_factory() as session:
|
|
result = await session.execute(select(func.count()).select_from(ChannelOAuthStateRow).where(*conditions))
|
|
return int(result.scalar_one())
|
|
|
|
async def consume_oauth_state(
|
|
self,
|
|
*,
|
|
provider: str,
|
|
state: str,
|
|
now: datetime | None = None,
|
|
) -> dict[str, Any] | None:
|
|
current_time = now or datetime.now(UTC)
|
|
state_hash = self.hash_state(state)
|
|
async with self.session_factory() as session:
|
|
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
|
|
row = await session.get(ChannelOAuthStateRow, state_hash)
|
|
if row is None or row.provider != provider or row.consumed_at is not None:
|
|
await session.commit()
|
|
return None
|
|
expires_at = self._coerce_datetime(row.expires_at)
|
|
if expires_at is not None and expires_at < current_time:
|
|
await session.commit()
|
|
return None
|
|
|
|
# Conditional UPDATE so two concurrent workers cannot both consume
|
|
# the same binding code: only the writer that flips consumed_at
|
|
# from NULL wins.
|
|
result = await session.execute(
|
|
update(ChannelOAuthStateRow)
|
|
.where(
|
|
ChannelOAuthStateRow.state_hash == state_hash,
|
|
ChannelOAuthStateRow.consumed_at.is_(None),
|
|
)
|
|
.values(consumed_at=current_time)
|
|
)
|
|
await session.commit()
|
|
if result.rowcount != 1:
|
|
return None
|
|
return {
|
|
"owner_user_id": row.owner_user_id,
|
|
"provider": row.provider,
|
|
"requested_scopes": row.requested_scopes_json or [],
|
|
"metadata": row.metadata_json or {},
|
|
"redirect_after": row.redirect_after,
|
|
}
|
|
|
|
async def find_connection_by_external_identity(
|
|
self,
|
|
*,
|
|
provider: str,
|
|
external_account_id: str,
|
|
workspace_id: str | None = None,
|
|
) -> dict[str, Any] | None:
|
|
async with self.session_factory() as session:
|
|
result = await session.execute(
|
|
select(ChannelConnectionRow)
|
|
.where(
|
|
ChannelConnectionRow.provider == provider,
|
|
ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
|
|
ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
|
|
ChannelConnectionRow.status == "connected",
|
|
)
|
|
.order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
|
|
.limit(1)
|
|
)
|
|
row = result.scalar_one_or_none()
|
|
return self._connection_to_dict(row) if row is not None else None
|
|
|
|
async def set_thread_id(
|
|
self,
|
|
*,
|
|
connection_id: str,
|
|
owner_user_id: str,
|
|
provider: str,
|
|
external_conversation_id: str,
|
|
thread_id: str,
|
|
external_topic_id: str | None = None,
|
|
) -> None:
|
|
topic_id = external_topic_id or ""
|
|
async with self.session_factory() as session:
|
|
stmt = select(ChannelConversationRow).where(
|
|
ChannelConversationRow.connection_id == connection_id,
|
|
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
|
ChannelConversationRow.external_topic_id == topic_id,
|
|
)
|
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
|
if row is None:
|
|
row = ChannelConversationRow(
|
|
id=self._new_id(),
|
|
connection_id=connection_id,
|
|
owner_user_id=owner_user_id,
|
|
provider=provider,
|
|
external_conversation_id=external_conversation_id,
|
|
external_topic_id=topic_id,
|
|
thread_id=thread_id,
|
|
)
|
|
session.add(row)
|
|
else:
|
|
row.thread_id = thread_id
|
|
row.owner_user_id = owner_user_id
|
|
row.provider = provider
|
|
await session.commit()
|
|
|
|
async def get_thread_id(
|
|
self,
|
|
connection_id: str,
|
|
external_conversation_id: str,
|
|
external_topic_id: str | None = None,
|
|
) -> str | None:
|
|
async with self.session_factory() as session:
|
|
stmt = select(ChannelConversationRow.thread_id).where(
|
|
ChannelConversationRow.connection_id == connection_id,
|
|
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
|
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
|
|
)
|
|
return (await session.execute(stmt)).scalar_one_or_none()
|