Support local IM channel connections

This commit is contained in:
taohe
2026-06-10 21:59:33 +08:00
parent 9effa7be6d
commit 92c185b90d
16 changed files with 381 additions and 53 deletions
@@ -2,8 +2,13 @@
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field, model_validator
ChannelConnectionMode = Literal["local", "private", "public"]
TelegramDeliveryMode = Literal["polling", "webhook"]
class SlackChannelConnectionConfig(BaseModel):
enabled: bool = False
@@ -29,13 +34,16 @@ class TelegramChannelConnectionConfig(BaseModel):
enabled: bool = False
bot_token: str = ""
bot_username: str = ""
delivery: TelegramDeliveryMode = "polling"
webhook_secret: str = ""
oidc_client_id: str = ""
oidc_client_secret: str = ""
@property
def configured(self) -> bool:
return bool(self.bot_token and self.bot_username and self.webhook_secret)
if self.delivery == "webhook":
return bool(self.bot_token and self.bot_username and self.webhook_secret)
return bool(self.bot_token and self.bot_username)
class DiscordChannelConnectionConfig(BaseModel):
@@ -55,6 +63,7 @@ class ChannelConnectionsConfig(BaseModel):
"""Top-level config for browser-connectable IM channels."""
enabled: bool = False
mode: ChannelConnectionMode = "local"
public_base_url: str = ""
encryption_key: str = ""
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
@@ -64,10 +73,8 @@ class ChannelConnectionsConfig(BaseModel):
@model_validator(mode="after")
def _require_shared_config_when_enabled(self) -> ChannelConnectionsConfig:
missing: list[str] = []
if self.enabled and not self.public_base_url:
missing.append("public_base_url is required when channel_connections.enabled is true")
if self.enabled and not self.encryption_key:
missing.append("encryption_key is required when channel_connections.enabled is true")
if self.enabled and self.mode == "public" and not self.public_base_url:
missing.append("public_base_url is required when channel_connections.mode is public")
if missing:
raise ValueError("; ".join(missing))
return self
@@ -53,7 +53,7 @@ class ChannelConnectionRepository:
self,
session_factory: async_sessionmaker[AsyncSession],
*,
cipher: ChannelCredentialCipher,
cipher: ChannelCredentialCipher | None = None,
) -> None:
self.session_factory = session_factory
self._cipher = cipher
@@ -77,6 +77,13 @@ class ChannelConnectionRepository:
return value
return value.replace(tzinfo=UTC)
def _encrypt_optional_secret(self, value: str | None) -> str | None:
if value is None:
return None
if self._cipher is None:
raise RuntimeError("channel connection encryption key is required")
return self._cipher.encrypt_text(value)
@staticmethod
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
data = row.to_dict()
@@ -166,6 +173,8 @@ class ChannelConnectionRepository:
refresh_expires_at: datetime | None = None,
extra: dict[str, Any] | None = None,
) -> None:
if self._cipher is None:
raise RuntimeError("channel connection encryption key is required")
async with self.session_factory() as session:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
@@ -181,6 +190,8 @@ class ChannelConnectionRepository:
await session.commit()
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
if self._cipher is None:
return None
async with self.session_factory() as session:
row = await session.get(ChannelCredentialRow, connection_id)
if row is None:
@@ -217,7 +228,7 @@ class ChannelConnectionRepository:
state_hash=self.hash_state(state),
owner_user_id=owner_user_id,
provider=provider,
code_verifier_encrypted=self._cipher.encrypt_text(code_verifier),
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
nonce_hash=nonce_hash,
redirect_after=redirect_after,
requested_scopes_json=list(requested_scopes or []),