mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
Support local IM channel connections
This commit is contained in:
@@ -2,8 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
ChannelConnectionMode = Literal["local", "private", "public"]
|
||||
TelegramDeliveryMode = Literal["polling", "webhook"]
|
||||
|
||||
|
||||
class SlackChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
@@ -29,13 +34,16 @@ class TelegramChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
bot_token: str = ""
|
||||
bot_username: str = ""
|
||||
delivery: TelegramDeliveryMode = "polling"
|
||||
webhook_secret: str = ""
|
||||
oidc_client_id: str = ""
|
||||
oidc_client_secret: str = ""
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return bool(self.bot_token and self.bot_username and self.webhook_secret)
|
||||
if self.delivery == "webhook":
|
||||
return bool(self.bot_token and self.bot_username and self.webhook_secret)
|
||||
return bool(self.bot_token and self.bot_username)
|
||||
|
||||
|
||||
class DiscordChannelConnectionConfig(BaseModel):
|
||||
@@ -55,6 +63,7 @@ class ChannelConnectionsConfig(BaseModel):
|
||||
"""Top-level config for browser-connectable IM channels."""
|
||||
|
||||
enabled: bool = False
|
||||
mode: ChannelConnectionMode = "local"
|
||||
public_base_url: str = ""
|
||||
encryption_key: str = ""
|
||||
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
|
||||
@@ -64,10 +73,8 @@ class ChannelConnectionsConfig(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _require_shared_config_when_enabled(self) -> ChannelConnectionsConfig:
|
||||
missing: list[str] = []
|
||||
if self.enabled and not self.public_base_url:
|
||||
missing.append("public_base_url is required when channel_connections.enabled is true")
|
||||
if self.enabled and not self.encryption_key:
|
||||
missing.append("encryption_key is required when channel_connections.enabled is true")
|
||||
if self.enabled and self.mode == "public" and not self.public_base_url:
|
||||
missing.append("public_base_url is required when channel_connections.mode is public")
|
||||
if missing:
|
||||
raise ValueError("; ".join(missing))
|
||||
return self
|
||||
|
||||
@@ -53,7 +53,7 @@ class ChannelConnectionRepository:
|
||||
self,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
*,
|
||||
cipher: ChannelCredentialCipher,
|
||||
cipher: ChannelCredentialCipher | None = None,
|
||||
) -> None:
|
||||
self.session_factory = session_factory
|
||||
self._cipher = cipher
|
||||
@@ -77,6 +77,13 @@ class ChannelConnectionRepository:
|
||||
return value
|
||||
return value.replace(tzinfo=UTC)
|
||||
|
||||
def _encrypt_optional_secret(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self._cipher is None:
|
||||
raise RuntimeError("channel connection encryption key is required")
|
||||
return self._cipher.encrypt_text(value)
|
||||
|
||||
@staticmethod
|
||||
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
|
||||
data = row.to_dict()
|
||||
@@ -166,6 +173,8 @@ class ChannelConnectionRepository:
|
||||
refresh_expires_at: datetime | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if self._cipher is None:
|
||||
raise RuntimeError("channel connection encryption key is required")
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelCredentialRow, connection_id)
|
||||
if row is None:
|
||||
@@ -181,6 +190,8 @@ class ChannelConnectionRepository:
|
||||
await session.commit()
|
||||
|
||||
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
|
||||
if self._cipher is None:
|
||||
return None
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelCredentialRow, connection_id)
|
||||
if row is None:
|
||||
@@ -217,7 +228,7 @@ class ChannelConnectionRepository:
|
||||
state_hash=self.hash_state(state),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
code_verifier_encrypted=self._cipher.encrypt_text(code_verifier),
|
||||
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
|
||||
nonce_hash=nonce_hash,
|
||||
redirect_after=redirect_after,
|
||||
requested_scopes_json=list(requested_scopes or []),
|
||||
|
||||
Reference in New Issue
Block a user