mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Align IM connections with local channels
This commit is contained in:
@@ -2,88 +2,48 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
ChannelConnectionMode = Literal["local", "private", "public"]
|
||||
TelegramDeliveryMode = Literal["polling", "webhook"]
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SlackChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
signing_secret: str = ""
|
||||
scopes: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"app_mentions:read",
|
||||
"chat:write",
|
||||
"channels:history",
|
||||
"channels:read",
|
||||
]
|
||||
)
|
||||
event_delivery: str = "http"
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return bool(self.client_id and self.client_secret and self.signing_secret)
|
||||
return True
|
||||
|
||||
|
||||
class TelegramChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
bot_token: str = ""
|
||||
bot_username: str = ""
|
||||
delivery: TelegramDeliveryMode = "polling"
|
||||
webhook_secret: str = ""
|
||||
oidc_client_id: str = ""
|
||||
oidc_client_secret: str = ""
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
if self.delivery == "webhook":
|
||||
return bool(self.bot_token and self.bot_username and self.webhook_secret)
|
||||
return bool(self.bot_token and self.bot_username)
|
||||
return bool(self.bot_username)
|
||||
|
||||
|
||||
class DiscordChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
bot_token: str = ""
|
||||
permissions: str = ""
|
||||
require_message_content_intent: bool = True
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return bool(self.client_id and self.client_secret and self.bot_token)
|
||||
return True
|
||||
|
||||
|
||||
class ChannelConnectionsConfig(BaseModel):
|
||||
"""Top-level config for browser-connectable IM channels."""
|
||||
|
||||
enabled: bool = False
|
||||
mode: ChannelConnectionMode = "local"
|
||||
public_base_url: str = ""
|
||||
encryption_key: str = ""
|
||||
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
|
||||
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
|
||||
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _require_shared_config_when_enabled(self) -> ChannelConnectionsConfig:
|
||||
missing: list[str] = []
|
||||
if self.enabled and self.mode == "public" and not self.public_base_url:
|
||||
missing.append("public_base_url is required when channel_connections.mode is public")
|
||||
if missing:
|
||||
raise ValueError("; ".join(missing))
|
||||
return self
|
||||
|
||||
def provider_status(self, provider: str) -> dict[str, bool]:
|
||||
config = getattr(self, provider, None)
|
||||
if config is None:
|
||||
return {"enabled": False, "configured": False}
|
||||
enabled = bool(config.enabled)
|
||||
return {
|
||||
"enabled": bool(config.enabled),
|
||||
"configured": bool(config.configured),
|
||||
"enabled": enabled,
|
||||
"configured": enabled and bool(config.configured),
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
ChannelWebhookDeliveryRow,
|
||||
)
|
||||
from deerflow.persistence.channel_connections.sql import (
|
||||
ChannelConnectionRepository,
|
||||
@@ -19,5 +18,4 @@ __all__ = [
|
||||
"ChannelCredentialCipher",
|
||||
"ChannelCredentialRow",
|
||||
"ChannelOAuthStateRow",
|
||||
"ChannelWebhookDeliveryRow",
|
||||
]
|
||||
|
||||
@@ -109,13 +109,3 @@ class ChannelConversationRow(Base):
|
||||
name="uq_channel_conversation_connection_external",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChannelWebhookDeliveryRow(Base):
|
||||
__tablename__ = "channel_webhook_deliveries"
|
||||
|
||||
provider: Mapped[str] = mapped_column(String(32), primary_key=True)
|
||||
delivery_id: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
payload_sha256: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
event_type: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
processed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
|
||||
@@ -18,7 +18,6 @@ from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
ChannelWebhookDeliveryRow,
|
||||
)
|
||||
from deerflow.utils.time import coerce_iso
|
||||
|
||||
@@ -345,30 +344,3 @@ class ChannelConnectionRepository:
|
||||
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
|
||||
)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
async def record_webhook_delivery(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
delivery_id: str,
|
||||
payload_sha256: str,
|
||||
event_type: str | None = None,
|
||||
) -> bool:
|
||||
async with self.session_factory() as session:
|
||||
existing = await session.get(
|
||||
ChannelWebhookDeliveryRow,
|
||||
{"provider": provider, "delivery_id": delivery_id},
|
||||
)
|
||||
if existing is not None:
|
||||
return False
|
||||
|
||||
session.add(
|
||||
ChannelWebhookDeliveryRow(
|
||||
provider=provider,
|
||||
delivery_id=delivery_id,
|
||||
payload_sha256=payload_sha256,
|
||||
event_type=event_type,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@@ -19,7 +19,6 @@ from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
ChannelWebhookDeliveryRow,
|
||||
)
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
@@ -32,7 +31,6 @@ __all__ = [
|
||||
"ChannelConversationRow",
|
||||
"ChannelCredentialRow",
|
||||
"ChannelOAuthStateRow",
|
||||
"ChannelWebhookDeliveryRow",
|
||||
"FeedbackRow",
|
||||
"RunEventRow",
|
||||
"RunRow",
|
||||
|
||||
Reference in New Issue
Block a user