mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
Add user-owned IM channel connections
This commit is contained in:
@@ -52,6 +52,56 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
||||
return default
|
||||
|
||||
|
||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return
|
||||
|
||||
telegram = getattr(connection_config, "telegram", None)
|
||||
if telegram is not None and getattr(telegram, "enabled", False) and getattr(telegram, "configured", False):
|
||||
telegram_config = dict(channels_config.get("telegram", {})) if isinstance(channels_config.get("telegram"), dict) else {}
|
||||
telegram_config.setdefault("enabled", True)
|
||||
telegram_config.setdefault("bot_token", telegram.bot_token)
|
||||
channels_config["telegram"] = telegram_config
|
||||
|
||||
slack = getattr(connection_config, "slack", None)
|
||||
if slack is not None and getattr(slack, "enabled", False) and getattr(slack, "configured", False):
|
||||
slack_config = dict(channels_config.get("slack", {})) if isinstance(channels_config.get("slack"), dict) else {}
|
||||
slack_config.setdefault("enabled", True)
|
||||
slack_config.setdefault("event_delivery", slack.event_delivery)
|
||||
slack_config.setdefault("signing_secret", slack.signing_secret)
|
||||
channels_config["slack"] = slack_config
|
||||
|
||||
discord = getattr(connection_config, "discord", None)
|
||||
if discord is not None and getattr(discord, "enabled", False) and getattr(discord, "configured", False):
|
||||
discord_config = dict(channels_config.get("discord", {})) if isinstance(channels_config.get("discord"), dict) else {}
|
||||
discord_config.setdefault("enabled", True)
|
||||
discord_config.setdefault("bot_token", discord.bot_token)
|
||||
channels_config["discord"] = discord_config
|
||||
|
||||
|
||||
def _make_connection_repo(app_config: AppConfig):
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return None
|
||||
encryption_key = getattr(connection_config, "encryption_key", "")
|
||||
if not encryption_key:
|
||||
return None
|
||||
|
||||
try:
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel connection repository")
|
||||
return None
|
||||
|
||||
session_factory = get_session_factory()
|
||||
if session_factory is None:
|
||||
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||
return None
|
||||
return ChannelConnectionRepository(session_factory, cipher=ChannelCredentialCipher.from_key(encryption_key))
|
||||
|
||||
|
||||
class ChannelService:
|
||||
"""Manages the lifecycle of all configured IM channels.
|
||||
|
||||
@@ -59,9 +109,10 @@ class ChannelService:
|
||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||
"""
|
||||
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||
self.bus = MessageBus()
|
||||
self.store = ChannelStore()
|
||||
self._connection_repo = connection_repo
|
||||
config = dict(channels_config or {})
|
||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||
@@ -74,6 +125,7 @@ class ChannelService:
|
||||
gateway_url=gateway_url,
|
||||
default_session=default_session if isinstance(default_session, dict) else None,
|
||||
channel_sessions=channel_sessions,
|
||||
connection_repo=connection_repo,
|
||||
)
|
||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||
self._config = config
|
||||
@@ -90,8 +142,9 @@ class ChannelService:
|
||||
# extra fields are allowed by AppConfig (extra="allow")
|
||||
extra = app_config.model_extra or {}
|
||||
if "channels" in extra:
|
||||
channels_config = extra["channels"]
|
||||
return cls(channels_config=channels_config)
|
||||
channels_config = dict(extra["channels"] or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the manager and all enabled channels."""
|
||||
@@ -169,6 +222,8 @@ class ChannelService:
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
if self._connection_repo is not None:
|
||||
config["connection_repo"] = self._connection_repo
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
|
||||
Reference in New Issue
Block a user