# Mirror of https://github.com/bytedance/deer-flow.git
# (synced 2026-06-11 18:05:58 +00:00; 445 lines, 16 KiB, Python)
"""Browser-facing APIs for user-owned IM channel bindings."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import secrets
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, HTTPException, Request, Response
|
|
from pydantic import BaseModel, Field
|
|
|
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
|
from deerflow.persistence.engine import get_session_factory
|
|
|
|
# Router for the browser-facing channel-connection endpoints; every route in
# this module is mounted under ``/api/channels``.
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])

logger = logging.getLogger(__name__)

# Lifetime, in seconds, of a binding code issued by the connect endpoint. It is
# used both for the stored state's ``expires_at`` and for the ``expires_in``
# value reported back to the client.
_STATE_TTL_SECONDS = 600
|
|
|
|
|
|
class ChannelCredentialFieldResponse(BaseModel):
    """One credential input the UI should render for a provider's runtime config."""

    # Key under which the value is submitted in ``ChannelRuntimeConfigRequest.values``.
    name: str
    # Human-readable label; also used in the "missing configuration" error message.
    label: str
    # Input type hint; the built-in providers use "text" or "password".
    type: str = "text"
    # Required fields must be non-blank or the runtime-config request is rejected (400).
    required: bool = True
|
|
|
|
|
class ChannelProviderResponse(BaseModel):
    """Status of one channel provider as seen by the requesting user."""

    # Provider key, e.g. "telegram" or "slack".
    provider: str
    # Human-readable provider name.
    display_name: str
    # Whether the provider is enabled in the channel-connections config.
    enabled: bool
    # True only when the provider is declared configured AND its runtime
    # channel section is enabled with every required credential present.
    configured: bool
    # ``enabled and configured`` with no ``unavailable_reason``.
    connectable: bool
    # User-facing hint explaining why the provider cannot be connected yet.
    unavailable_reason: str | None = None
    # How binding works: "deep_link" (Telegram) or "binding_code".
    auth_mode: str
    # Status of the caller's newest connection for this provider, or "not_connected".
    connection_status: str
    # Credential inputs the UI should render for runtime configuration.
    credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
|
|
|
|
|
|
class ChannelProvidersResponse(BaseModel):
    """Response body of ``GET /api/channels/providers``."""

    # Global channel-connections feature flag.
    enabled: bool
    # Only providers that are enabled in config are listed.
    providers: list[ChannelProviderResponse]
|
|
|
|
|
|
class ChannelConnectionResponse(BaseModel):
    """A user's binding to an IM account, built from a repository row."""

    # Connection id; this is the id accepted by the DELETE endpoint.
    id: str
    # Provider key, e.g. "telegram".
    provider: str
    # Connection status string as stored by the repository (values not defined here).
    status: str
    # NOTE(review): presumably the IM-side account id/name and workspace id/name
    # captured at binding time — confirm against the repository schema.
    external_account_id: str | None = None
    external_account_name: str | None = None
    workspace_id: str | None = None
    workspace_name: str | None = None
    # Scopes granted to the connection, if the provider records any.
    scopes: list[str] = Field(default_factory=list)
    # Free-form provider-specific metadata.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class ChannelConnectionsResponse(BaseModel):
    """Response body of ``GET /api/channels/connections`` (the caller's connections)."""

    connections: list[ChannelConnectionResponse]
|
|
|
|
|
|
class ChannelConnectResponse(BaseModel):
    """Response body of ``POST /api/channels/{provider}/connect``."""

    # Provider key the binding code was issued for.
    provider: str
    # The provider's auth mode: "deep_link" or "binding_code".
    mode: str
    # Deep link to open (Telegram); ``None`` for binding-code providers.
    url: str | None = None
    # One-time binding code; it is also stored as the pending state in the repository.
    code: str
    # Human-readable instruction telling the user what to send to the bot.
    instruction: str
    # Seconds until the binding code expires.
    expires_in: int
|
|
|
|
|
|
class ChannelRuntimeConfigRequest(BaseModel):
    """Request body of ``POST /api/channels/{provider}/runtime-config``."""

    # Credential values keyed by credential field name (see ``_CREDENTIAL_FIELDS``).
    values: dict[str, str] = Field(default_factory=dict)
|
|
|
|
|
|
# Static metadata for every supported provider. The key set is the canonical
# list of providers this module exposes; "auth_mode" decides whether connecting
# yields a Telegram deep link or a plain binding code.
_PROVIDER_META: dict[str, dict[str, str]] = {
    "telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
    "slack": {"display_name": "Slack", "auth_mode": "binding_code"},
    "discord": {"display_name": "Discord", "auth_mode": "binding_code"},
    "feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
    "dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
    "wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
    "wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
}
|
|
|
|
# Credential inputs per provider, as keyword arguments for
# ``ChannelCredentialFieldResponse``. These are rendered by the UI and
# validated by the runtime-config endpoint; every field is required by default.
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
    "telegram": (
        {"name": "bot_token", "label": "Bot token", "type": "password"},
        {"name": "bot_username", "label": "Bot username", "type": "text"},
    ),
    "slack": (
        {"name": "bot_token", "label": "Bot token", "type": "password"},
        {"name": "app_token", "label": "App token", "type": "password"},
    ),
    "discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
    "feishu": (
        {"name": "app_id", "label": "App ID", "type": "text"},
        {"name": "app_secret", "label": "App secret", "type": "password"},
    ),
    "dingtalk": (
        {"name": "client_id", "label": "Client ID", "type": "text"},
        {"name": "client_secret", "label": "Client secret", "type": "password"},
    ),
    "wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
    "wecom": (
        {"name": "bot_id", "label": "Bot ID", "type": "text"},
        {"name": "bot_secret", "label": "Bot secret", "type": "password"},
    ),
}
|
|
|
|
# Keys that must be present and non-blank in a provider's runtime channel
# section for it to count as configured; the runtime-config endpoint copies
# exactly these keys into that section. Telegram's "bot_username" is not
# listed: the runtime-config endpoint stores it on the provider config instead.
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
    "telegram": ("bot_token",),
    "slack": ("bot_token", "app_token"),
    "discord": ("bot_token",),
    "feishu": ("app_id", "app_secret"),
    "dingtalk": ("client_id", "client_secret"),
    "wechat": ("bot_token",),
    "wecom": ("bot_id", "bot_secret"),
}
|
|
|
|
|
|
def _get_user_id(request: Request) -> str:
|
|
user = getattr(request.state, "user", None)
|
|
if user is None:
|
|
raise HTTPException(status_code=401, detail="Authentication required")
|
|
return str(user.id)
|
|
|
|
|
|
def _get_app_config():
    """Load and return the application config on demand.

    The import stays inside the function. NOTE(review): presumably this avoids
    an import cycle or config loading at import time — confirm before hoisting
    it to module level.
    """
    from deerflow.config.app_config import get_app_config as load_app_config

    return load_app_config()
|
|
|
|
|
|
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
    """Return the channel-connections config for this request.

    An instance stored on ``request.app.state.channel_connections_config``
    takes precedence. Otherwise the section is read from the application
    config.
    """
    cached = getattr(request.app.state, "channel_connections_config", None)
    if not isinstance(cached, ChannelConnectionsConfig):
        return _get_app_config().channel_connections
    return cached
|
|
|
|
|
|
def _get_channels_config(request: Request) -> dict[str, Any]:
|
|
state_config = getattr(request.app.state, "channels_config", None)
|
|
if isinstance(state_config, dict):
|
|
return state_config
|
|
|
|
app_config = _get_app_config()
|
|
extra = app_config.model_extra or {}
|
|
channels_config = extra.get("channels")
|
|
return dict(channels_config) if isinstance(channels_config, dict) else {}
|
|
|
|
|
|
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
    """Return the channel-connection repository, creating and caching it lazily.

    A repository already stored on ``request.app.state.channel_connection_repo``
    is reused. Otherwise one is built from the persistence session factory and
    cached on app state. ``config`` is accepted for interface compatibility
    but is not used.

    Raises:
        HTTPException: 503 when no session factory is available.
    """
    app_state = request.app.state
    cached = getattr(app_state, "channel_connection_repo", None)
    if isinstance(cached, ChannelConnectionRepository):
        return cached

    session_factory = get_session_factory()
    if session_factory is None:
        raise HTTPException(status_code=503, detail="Channel connection persistence is not available")

    fresh_repo = ChannelConnectionRepository(session_factory)
    app_state.channel_connection_repo = fresh_repo
    return fresh_repo
|
|
|
|
|
|
def _provider_config(config: ChannelConnectionsConfig, provider: str):
    """Return the per-provider section of ``config`` for a known provider.

    ``provider`` usually comes straight from the URL path, so it is checked
    against ``_PROVIDER_META`` before the attribute lookup. Without that check,
    a path segment such as ``enabled`` resolved to the config's own boolean
    flag (or ``provider_status`` to a method) instead of a provider section.
    Callers reading ``.enabled`` on it then failed with a 500 rather than a 404.

    Raises:
        HTTPException: 404 when the provider is unknown or has no config section.
    """
    if provider not in _PROVIDER_META:
        raise HTTPException(status_code=404, detail="Unknown channel provider")
    provider_config = getattr(config, provider, None)
    if provider_config is None:
        raise HTTPException(status_code=404, detail="Unknown channel provider")
    return provider_config
|
|
|
|
|
|
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
    """Tell whether ``provider`` has a usable runtime channel section.

    The section must be a dict, must have a truthy ``enabled``, and must
    carry a non-blank value for every key in ``_RUNTIME_REQUIREMENTS[provider]``.
    """
    section = channels_config.get(provider)
    if not isinstance(section, dict):
        return False
    if not section.get("enabled", False):
        return False
    for required_key in _RUNTIME_REQUIREMENTS[provider]:
        if not str(section.get(required_key) or "").strip():
            return False
    return True
|
|
|
|
|
|
def _runtime_unavailable_reason(provider: str) -> str:
    """Build the user-facing hint asking for a provider's credentials.

    Uses the provider's display name when it is known, otherwise the raw key.
    """
    meta = _PROVIDER_META.get(provider)
    label = provider if not meta else meta["display_name"]
    return f"Enter the required {label} credentials to connect this channel."
|
|
|
|
|
|
def _provider_unavailable_reason(
    config: ChannelConnectionsConfig,
    channels_config: dict[str, Any],
    provider: str,
) -> str | None:
    """Explain why an enabled provider cannot be connected, or return ``None``.

    A disabled provider yields ``None`` because there is nothing to explain.
    An enabled provider that is either not declared configured or lacks a
    complete runtime channel section yields the "enter credentials" hint.
    """
    provider_config = _provider_config(config, provider)
    if not provider_config.enabled:
        return None
    ready = provider_config.configured and _runtime_channel_configured(provider, channels_config)
    return None if ready else _runtime_unavailable_reason(provider)
|
|
|
|
|
|
def _provider_status(
    config: ChannelConnectionsConfig,
    channels_config: dict[str, Any],
    provider: str,
) -> tuple[dict[str, bool], str | None]:
    """Combine declared config status with runtime readiness.

    Returns:
        A ``({"enabled": ..., "configured": ...}, unavailable_reason)`` pair,
        where "configured" additionally requires a complete runtime section.
    """
    declared_status = config.provider_status(provider)
    reason = _provider_unavailable_reason(config, channels_config, provider)
    runtime_ready = declared_status["configured"] and _runtime_channel_configured(provider, channels_config)
    merged = {"enabled": declared_status["enabled"], "configured": runtime_ready}
    return merged, reason
|
|
|
|
|
|
def _new_binding_code() -> str:
|
|
return secrets.token_urlsafe(16)
|
|
|
|
|
|
async def _create_state(
    repo: ChannelConnectionRepository,
    *,
    owner_user_id: str,
    provider: str,
) -> str:
    """Persist a new pending binding for ``owner_user_id`` and return its code.

    The code is stored through ``repo.create_oauth_state`` and expires
    ``_STATE_TTL_SECONDS`` seconds from now (UTC).
    """
    binding_code = _new_binding_code()
    expiry = datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS)
    await repo.create_oauth_state(
        owner_user_id=owner_user_id,
        provider=provider,
        state=binding_code,
        expires_at=expiry,
    )
    return binding_code
|
|
|
|
|
|
def _connect_instruction(provider: str, code: str) -> str:
|
|
if provider == "telegram":
|
|
return f"Send /start {code} to the DeerFlow Telegram bot."
|
|
meta = _PROVIDER_META.get(provider)
|
|
if meta is None:
|
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
|
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
|
|
|
|
|
|
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
    """Return the link the browser should open to start binding, if any.

    Telegram uses a ``https://t.me/<bot_username>?start=<code>`` deep link.
    Providers whose auth mode is ``binding_code`` have no link (``None``); the
    user follows the text instruction instead.

    Raises:
        HTTPException: 404 for a provider that is neither of the above.
    """
    if provider == "telegram":
        provider_config = _provider_config(config, provider)
        # The username can be entered free-form through the runtime-config
        # endpoint, so it may carry the "@" mention prefix, which t.me deep
        # links do not use. Without any username there is no usable link, so
        # report none instead of emitting "https://t.me/None?start=...".
        bot_username = str(provider_config.bot_username or "").strip().lstrip("@")
        if not bot_username:
            return None
        return f"https://t.me/{bot_username}?start={code}"
    if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
        return None
    raise HTTPException(status_code=404, detail="Unknown channel provider")
|
|
|
|
|
|
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
|
value = connection.get("updated_at")
|
|
if isinstance(value, datetime):
|
|
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
|
if isinstance(value, str) and value:
|
|
try:
|
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
|
except ValueError:
|
|
pass
|
|
return datetime.min.replace(tzinfo=UTC)
|
|
|
|
|
|
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Keep only the most recently updated connection for each provider.

    Ties keep the row seen first, because a later row replaces the current
    one only when it is strictly newer.
    """
    newest: dict[str, dict[str, Any]] = {}
    for candidate in connections:
        provider_key = candidate["provider"]
        current = newest.get(provider_key)
        if current is not None and _connection_updated_at(candidate) <= _connection_updated_at(current):
            continue
        newest[provider_key] = candidate
    return newest
|
|
|
|
|
|
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
    """Return the credential inputs declared for ``provider``.

    Raises:
        HTTPException: 404 when the provider has no entry in ``_CREDENTIAL_FIELDS``.
    """
    if provider not in _CREDENTIAL_FIELDS:
        raise HTTPException(status_code=404, detail="Unknown channel provider")
    return [ChannelCredentialFieldResponse(**spec) for spec in _CREDENTIAL_FIELDS[provider]]
|
|
|
|
|
|
def _provider_response(
    config: ChannelConnectionsConfig,
    channels_config: dict[str, Any],
    provider: str,
    meta: dict[str, str],
    connection: dict[str, Any] | None = None,
) -> ChannelProviderResponse:
    """Assemble the API view of one provider for the requesting user.

    ``connection`` is the caller's newest connection for the provider; when it
    is absent or empty, the status is reported as "not_connected".
    """
    status, unavailable_reason = _provider_status(config, channels_config, provider)
    is_enabled = status["enabled"]
    is_configured = status["configured"]
    can_connect = is_enabled and is_configured and unavailable_reason is None
    connection_state = "not_connected" if not connection else connection["status"]
    return ChannelProviderResponse(
        provider=provider,
        display_name=meta["display_name"],
        enabled=is_enabled,
        configured=is_configured,
        connectable=can_connect,
        unavailable_reason=unavailable_reason,
        auth_mode=meta["auth_mode"],
        connection_status=connection_state,
        credential_fields=_credential_fields(provider),
    )
|
|
|
|
|
|
def _required_runtime_values(provider: str, values: dict[str, str]) -> dict[str, str]:
    """Normalise submitted credentials against the provider's declared fields.

    Every declared field is returned, whitespace-stripped; submitted keys that
    are not declared fields are dropped.

    Raises:
        HTTPException: 400 listing the labels of required fields left blank.
    """
    cleaned: dict[str, str] = {}
    missing_labels: list[str] = []
    for spec in _credential_fields(provider):
        raw = values.get(spec.name, "")
        text = (raw if isinstance(raw, str) else str(raw or "")).strip()
        if spec.required and not text:
            missing_labels.append(spec.label)
        cleaned[spec.name] = text

    if missing_labels:
        raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing_labels)}")
    return cleaned
|
|
|
|
|
|
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
    """Ask the running channel service, if any, to (re)configure ``provider``.

    Returns:
        ``None`` when the channel service cannot be imported or is not
        running. Otherwise, whatever ``configure_channel`` returns; the caller
        treats ``False`` as a failed start.
    """
    try:
        from app.channels.service import get_channel_service
    except Exception:
        # Best effort: a missing service must not break saving the config.
        logger.exception("Failed to import channel service while configuring %s", provider)
        return None

    channel_service = get_channel_service()
    if channel_service is not None:
        return await channel_service.configure_channel(provider, runtime_config)
    return None
|
|
|
|
|
|
@router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
    """List the providers enabled in config, annotated with the caller's state.

    When channel connections are enabled, the caller must be authenticated, and
    their newest connection per provider is attached. A 503 from the repository
    lookup (persistence unavailable) is tolerated and treated as "no
    connections"; any other HTTP error propagates.
    """
    config = _get_channel_connections_config(request)
    channels_config = _get_channels_config(request)

    latest: dict[str, dict[str, Any]] = {}
    if config.enabled:
        repository = None
        try:
            repository = _get_repository(request, config)
        except HTTPException as exc:
            if exc.status_code != 503:
                raise
        user_id = _get_user_id(request)
        if repository is not None:
            rows = await repository.list_connections(user_id)
            latest = _newest_connection_by_provider(rows)

    visible = [
        _provider_response(config, channels_config, name, meta, latest.get(name))
        for name, meta in _PROVIDER_META.items()
        if config.provider_status(name)["enabled"]
    ]
    return ChannelProvidersResponse(enabled=config.enabled, providers=visible)
|
|
|
|
|
|
@router.get("/connections", response_model=ChannelConnectionsResponse)
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
    """Return every channel connection owned by the requesting user.

    When channel connections are disabled, an empty list is returned without
    touching the repository.
    """
    config = _get_channel_connections_config(request)
    if config.enabled:
        repository = _get_repository(request, config)
        owned_rows = await repository.list_connections(_get_user_id(request))
        return ChannelConnectionsResponse(
            connections=[ChannelConnectionResponse(**row) for row in owned_rows],
        )
    return ChannelConnectionsResponse(connections=[])
|
|
|
|
|
|
@router.delete("/connections/{connection_id}", status_code=204)
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
    """Disconnect one of the caller's channel connections.

    Raises:
        HTTPException: 400 when channel connections are disabled; 404 when the
            repository reports nothing was disconnected for this owner/id.
    """
    config = _get_channel_connections_config(request)
    if not config.enabled:
        raise HTTPException(status_code=400, detail="Channel connections are disabled")

    repository = _get_repository(request, config)
    owner_user_id = _get_user_id(request)
    if await repository.disconnect_connection(connection_id=connection_id, owner_user_id=owner_user_id):
        return Response(status_code=204)
    raise HTTPException(status_code=404, detail="Channel connection not found")
|
|
|
|
|
|
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
    """Issue a short-lived binding code the caller sends to the provider's bot.

    Raises:
        HTTPException: 400 when channel connections are disabled, or when the
            provider is not enabled, not ready (with the specific reason), or
            not configured.
    """
    config = _get_channel_connections_config(request)
    channels_config = _get_channels_config(request)
    if not config.enabled:
        raise HTTPException(status_code=400, detail="Channel connections are disabled")

    status, unavailable_reason = _provider_status(config, channels_config, provider)
    rejection = None
    if not status["enabled"]:
        rejection = "Channel provider is not enabled"
    elif unavailable_reason:
        rejection = unavailable_reason
    elif not status["configured"]:
        rejection = "Channel provider is not configured"
    if rejection is not None:
        raise HTTPException(status_code=400, detail=rejection)

    repository = _get_repository(request, config)
    owner_user_id = _get_user_id(request)
    binding_code = await _create_state(repository, owner_user_id=owner_user_id, provider=provider)
    auth_mode = _PROVIDER_META[provider]["auth_mode"]
    return ChannelConnectResponse(
        provider=provider,
        mode=auth_mode,
        url=_connect_url(config, provider, binding_code),
        code=binding_code,
        instruction=_connect_instruction(provider, binding_code),
        expires_in=_STATE_TTL_SECONDS,
    )
|
|
|
|
|
|
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def configure_channel_provider_runtime(
    provider: str,
    body: ChannelRuntimeConfigRequest,
    request: Request,
) -> ChannelProviderResponse:
    """Store runtime credentials for ``provider`` and (re)start its channel.

    The validated credentials are written into the provider's runtime section
    of the channels config, which is marked enabled and saved on app state.
    For Telegram, the bot username is stored on the provider config. The
    running channel service, if any, is then asked to reconfigure the channel.

    If the service reports a failed start (``False``) or raises, the previous
    runtime section and Telegram bot username are restored. Previously the
    rejected credentials stayed in app state marked enabled, so the provider
    kept being reported as configured and connectable.

    Raises:
        HTTPException: 400 when channel connections or the provider are
            disabled, when required values are missing, or when the channel
            fails to start; 404 for an unknown provider.
    """
    config = _get_channel_connections_config(request)
    if not config.enabled:
        raise HTTPException(status_code=400, detail="Channel connections are disabled")

    provider_config = _provider_config(config, provider)
    if not provider_config.enabled:
        raise HTTPException(status_code=400, detail="Channel provider is not enabled")

    values = _required_runtime_values(provider, body.values)
    channels_config = _get_channels_config(request)
    existing = channels_config.get(provider)
    runtime_config = dict(existing) if isinstance(existing, dict) else {}
    runtime_config["enabled"] = True
    for key in _RUNTIME_REQUIREMENTS[provider]:
        runtime_config[key] = values[key]

    # Snapshot what is about to be overwritten so a failed start can be undone.
    had_previous_section = provider in channels_config
    is_telegram = provider == "telegram"
    previous_bot_username = provider_config.bot_username if is_telegram else None

    if is_telegram:
        provider_config.bot_username = values["bot_username"]
        request.app.state.channel_connections_config = config

    channels_config[provider] = runtime_config
    request.app.state.channels_config = channels_config

    def _rollback() -> None:
        """Restore the runtime section and Telegram username captured above."""
        if had_previous_section:
            channels_config[provider] = existing
        else:
            channels_config.pop(provider, None)
        if is_telegram:
            provider_config.bot_username = previous_bot_username

    try:
        started = await _restart_runtime_channel_if_available(provider, runtime_config)
    except Exception:
        _rollback()
        raise

    # ``None`` means no service was available to try; only an explicit False is a failure.
    if started is False:
        _rollback()
        display_name = _PROVIDER_META[provider]["display_name"]
        raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")

    return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|