mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Support local IM channel connections
This commit is contained in:
@@ -29,6 +29,8 @@ class ChannelProviderResponse(BaseModel):
|
||||
display_name: str
|
||||
enabled: bool
|
||||
configured: bool
|
||||
connectable: bool
|
||||
unavailable_reason: str | None = None
|
||||
auth_mode: str
|
||||
connection_status: str
|
||||
|
||||
@@ -93,10 +95,9 @@ def _get_repository(request: Request, config: ChannelConnectionsConfig) -> Chann
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||
if not config.encryption_key:
|
||||
raise HTTPException(status_code=503, detail="Channel connection encryption key is not configured")
|
||||
|
||||
repo = ChannelConnectionRepository(sf, cipher=ChannelCredentialCipher.from_key(config.encryption_key))
|
||||
cipher = ChannelCredentialCipher.from_key(config.encryption_key) if config.encryption_key else None
|
||||
repo = ChannelConnectionRepository(sf, cipher=cipher)
|
||||
request.app.state.channel_connection_repo = repo
|
||||
return repo
|
||||
|
||||
@@ -108,6 +109,43 @@ def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||
return provider_config
|
||||
|
||||
|
||||
def _provider_unavailable_reason(config: ChannelConnectionsConfig, provider: str) -> str | None:
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled or not provider_config.configured:
|
||||
return None
|
||||
|
||||
if provider == "telegram" and getattr(provider_config, "delivery", "polling") == "webhook":
|
||||
if not provider_config.webhook_secret:
|
||||
return "Telegram webhook delivery requires channel_connections.telegram.webhook_secret"
|
||||
if not config.public_base_url:
|
||||
return "Telegram webhook delivery requires channel_connections.public_base_url; use polling for local/private deployments"
|
||||
|
||||
if provider == "slack" and getattr(provider_config, "event_delivery", "http") == "http" and not config.public_base_url:
|
||||
return "Slack HTTP Events require channel_connections.public_base_url; use a public URL/tunnel or Slack Socket Mode for private deployments"
|
||||
|
||||
if provider in {"slack", "discord"} and not config.encryption_key:
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
return f"{display_name} connections require channel_connections.encryption_key to store OAuth credentials"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _require_provider_connectable(config: ChannelConnectionsConfig, provider: str) -> None:
|
||||
reason = _provider_unavailable_reason(config, provider)
|
||||
if reason:
|
||||
raise HTTPException(status_code=400, detail=reason)
|
||||
|
||||
|
||||
def _callback_base_url(config: ChannelConnectionsConfig, request: Request) -> str:
|
||||
if config.public_base_url:
|
||||
return config.public_base_url.rstrip("/")
|
||||
return str(request.base_url).rstrip("/")
|
||||
|
||||
|
||||
def _callback_redirect_uri(config: ChannelConnectionsConfig, request: Request, provider: str) -> str:
|
||||
return f"{_callback_base_url(config, request)}/api/channels/{provider}/callback"
|
||||
|
||||
|
||||
async def _create_state(
|
||||
repo: ChannelConnectionRepository,
|
||||
*,
|
||||
@@ -128,12 +166,12 @@ async def _create_state(
|
||||
return state
|
||||
|
||||
|
||||
def _build_connect_url(config: ChannelConnectionsConfig, provider: str, state: str) -> str:
|
||||
def _build_connect_url(config: ChannelConnectionsConfig, request: Request, provider: str, state: str) -> str:
|
||||
provider_config = _provider_config(config, provider)
|
||||
if provider == "telegram":
|
||||
return f"https://t.me/{provider_config.bot_username}?start={state}"
|
||||
|
||||
redirect_uri = f"{config.public_base_url.rstrip('/')}/api/channels/{provider}/callback"
|
||||
redirect_uri = _callback_redirect_uri(config, request, provider)
|
||||
if provider == "slack":
|
||||
query = urlencode(
|
||||
{
|
||||
@@ -242,7 +280,13 @@ async def _publish_slack_event(
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
repo = _get_repository(request, config) if config.enabled and config.encryption_key else None
|
||||
repo = None
|
||||
if config.enabled:
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
owner_user_id = _get_user_id(request)
|
||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||
by_provider = {item["provider"]: item for item in connections}
|
||||
@@ -251,12 +295,15 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
for provider, meta in _PROVIDER_META.items():
|
||||
status = config.provider_status(provider)
|
||||
connection = by_provider.get(provider)
|
||||
unavailable_reason = _provider_unavailable_reason(config, provider)
|
||||
providers.append(
|
||||
ChannelProviderResponse(
|
||||
provider=provider,
|
||||
display_name=meta["display_name"],
|
||||
enabled=status["enabled"],
|
||||
configured=status["configured"],
|
||||
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
|
||||
unavailable_reason=unavailable_reason,
|
||||
auth_mode=meta["auth_mode"],
|
||||
connection_status=connection["status"] if connection else "not_connected",
|
||||
)
|
||||
@@ -307,7 +354,7 @@ async def slack_oauth_callback(request: Request, code: str | None = None, state:
|
||||
if state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
|
||||
|
||||
redirect_uri = f"{config.public_base_url.rstrip('/')}/api/channels/slack/callback"
|
||||
redirect_uri = _callback_redirect_uri(config, request, "slack")
|
||||
install = await slack_connect.exchange_slack_oauth_code(
|
||||
client_id=provider_config.client_id,
|
||||
client_secret=provider_config.client_secret,
|
||||
@@ -351,7 +398,7 @@ async def discord_oauth_callback(request: Request, code: str | None = None, stat
|
||||
if state_data is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid or expired OAuth state")
|
||||
|
||||
redirect_uri = f"{config.public_base_url.rstrip('/')}/api/channels/discord/callback"
|
||||
redirect_uri = _callback_redirect_uri(config, request, "discord")
|
||||
identity = await discord_connect.complete_discord_oauth(
|
||||
client_id=provider_config.client_id,
|
||||
client_secret=provider_config.client_secret,
|
||||
@@ -471,6 +518,7 @@ async def connect_channel_provider(provider: str, request: Request) -> ChannelCo
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled or not provider_config.configured:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
_require_provider_connectable(config, provider)
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
state = await _create_state(
|
||||
@@ -482,6 +530,6 @@ async def connect_channel_provider(provider: str, request: Request) -> ChannelCo
|
||||
return ChannelConnectResponse(
|
||||
provider=provider,
|
||||
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||
url=_build_connect_url(config, provider, state),
|
||||
url=_build_connect_url(config, request, provider, state),
|
||||
expires_in=_STATE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user