mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Ensure runtime IM channels are ready after restart
This commit is contained in:
@@ -43,6 +43,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||
|
||||
|
||||
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
    """Report whether any credential field of a channel holds a usable value.

    The field names to inspect come from ``_CHANNEL_CREDENTIAL_KEYS[name]``; a
    channel without an entry there has nothing to inspect and yields ``False``.
    A field counts as set when its value is neither ``None`` nor a ``bool`` and
    its ``str()`` form is non-empty after stripping whitespace.
    """
    for field in _CHANNEL_CREDENTIAL_KEYS.get(name, []):
        candidate = channel_config.get(field)
        # Missing/None values and booleans never count as credentials.
        if candidate is None or isinstance(candidate, bool):
            continue
        if str(candidate).strip():
            return True
    return False
|
||||
|
||||
|
||||
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
||||
value = config.pop(config_key, None)
|
||||
if isinstance(value, str) and value.strip():
|
||||
@@ -127,14 +132,20 @@ class ChannelService:
|
||||
return
|
||||
|
||||
await self.manager.start()
|
||||
self._running = True
|
||||
|
||||
ready_status = await self.ensure_ready_channels(attempts=2)
|
||||
ready_count = sum(1 for ready in ready_status.values() if ready)
|
||||
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
|
||||
|
||||
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
|
||||
"""Start or restart enabled configured channels that are not ready."""
|
||||
ready_status: dict[str, bool] = {}
|
||||
for name, channel_config in self._config.items():
|
||||
if not isinstance(channel_config, dict):
|
||||
continue
|
||||
if not channel_config.get("enabled", False):
|
||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
||||
if has_creds:
|
||||
if _channel_has_credentials(name, channel_config):
|
||||
logger.warning(
|
||||
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
||||
)
|
||||
@@ -142,10 +153,49 @@ class ChannelService:
|
||||
logger.info("A configured channel is disabled, skipping")
|
||||
continue
|
||||
|
||||
await self._start_channel(name, channel_config)
|
||||
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
|
||||
return ready_status
|
||||
|
||||
self._running = True
|
||||
logger.info("ChannelService started with %d channels", len(self._channels))
|
||||
async def ensure_channel_ready(
    self,
    name: str,
    config: dict[str, Any] | None = None,
    *,
    attempts: int = 1,
) -> bool:
    """Bring one enabled channel into the running state.

    Args:
        name: Key of the channel in the service configuration.
        config: Optional replacement configuration; when given, a shallow
            copy is stored under ``name`` before anything else is checked.
        attempts: Number of start attempts; values below 1 are treated as 1.

    Returns:
        ``True`` if the channel is already running or one start attempt
        succeeds. ``False`` if the service itself is not running, the channel
        has no dict configuration, the channel is disabled, or every start
        attempt fails.
    """
    if not self._running:
        logger.warning("ChannelService is not running; cannot ensure channel readiness")
        return False

    # Copy the caller's mapping so later mutations on their side do not leak in.
    if config is not None:
        self._config[name] = dict(config)

    settings = self._config.get(name)
    if not (settings and isinstance(settings, dict)):
        logger.warning("No config for requested channel")
        return False
    if not settings.get("enabled", False):
        return False

    existing = self._channels.get(name)
    if existing is not None:
        if existing.is_running:
            return True
        # A registered but non-running instance is stopped and dropped so the
        # start attempts below begin from a clean slate.
        try:
            await existing.stop()
        except Exception:
            logger.exception("Error stopping non-running channel before readiness retry")
        self._channels.pop(name, None)

    for attempt_index in range(max(1, attempts)):
        if attempt_index:
            logger.info("Retrying channel startup after readiness check")
        started = await self._start_channel(name, settings)
        if started:
            return True
    return False
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all channels and the manager."""
|
||||
|
||||
@@ -217,7 +217,7 @@ def _runtime_unavailable_reason(provider: str) -> str:
|
||||
def _runtime_not_running_reason(provider: str) -> str:
    """Build the message shown when a configured channel is not running.

    Uses the ``display_name`` from ``_PROVIDER_META`` when the provider has an
    entry there and falls back to the raw provider key otherwise.
    """
    meta = _PROVIDER_META.get(provider)
    display_name = meta["display_name"] if meta else provider
    return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
|
||||
|
||||
|
||||
def _runtime_channel_running(provider: str) -> bool | None:
|
||||
@@ -244,6 +244,35 @@ def _runtime_channel_running(provider: str) -> bool | None:
|
||||
return bool(channel_status.get("running"))
|
||||
|
||||
|
||||
async def _ensure_runtime_channel_ready_if_available(
    provider: str,
    channels_config: dict[str, Any],
) -> bool | None:
    """Ask the channel service, if one is reachable, to make a provider's channel ready.

    Args:
        provider: Key of the provider inside ``channels_config``.
        channels_config: Mapping of provider key to its runtime channel config.

    Returns:
        ``None`` when nothing was attempted: the provider's entry is not a dict
        or is not enabled, the service module cannot be imported, no service
        instance exists, or the service has no ``ensure_channel_ready``.
        Otherwise the result of ``ensure_channel_ready``, or ``False`` if that
        call raised.
    """
    provider_settings = channels_config.get(provider)
    if not (isinstance(provider_settings, dict) and provider_settings.get("enabled", False)):
        return None

    # Imported lazily and defensively: a missing or broken service module must
    # not break the caller, it just means there is nothing to reconcile.
    try:
        from app.channels.service import get_channel_service
    except Exception:
        logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
        return None

    channel_service = get_channel_service()
    if channel_service is None:
        reconcile = None
    else:
        reconcile = getattr(channel_service, "ensure_channel_ready", None)
    if reconcile is None:
        return None

    try:
        outcome = await reconcile(provider, provider_settings)
    except Exception:
        logger.exception("Failed to reconcile runtime channel readiness")
        return False
    return outcome
|
||||
|
||||
|
||||
def _provider_unavailable_reason(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
@@ -459,6 +488,8 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
for provider, meta in _PROVIDER_META.items():
|
||||
if not config.provider_status(provider)["enabled"]:
|
||||
continue
|
||||
if _runtime_channel_configured(provider, channels_config):
|
||||
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||
connection = by_provider.get(provider)
|
||||
providers.append(_provider_response(config, channels_config, provider, meta, connection))
|
||||
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||
@@ -535,6 +566,10 @@ async def connect_channel_provider(provider: str, request: Request) -> ChannelCo
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
|
||||
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
if not status["enabled"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
|
||||
@@ -247,6 +247,60 @@ def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypa
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_restarts_configured_channel_when_service_can_reconcile(tmp_path, monkeypatch):
    """Listing providers asks the channel service to revive a configured but stopped channel.

    The fake service reports feishu as not running until its
    ``ensure_channel_ready`` is called; the endpoint response must then show
    the provider as connected with no unavailable reason.
    """
    import anyio

    repo = anyio.run(_make_repo, tmp_path)
    # Close the repo even when a request or assertion below fails.
    try:
        config = ChannelConnectionsConfig.model_validate(
            {
                "enabled": True,
                "feishu": {"enabled": True},
            }
        )
        channels_config = {
            "feishu": {
                "enabled": True,
                "app_id": "feishu-app",
                "app_secret": "feishu-secret",
            }
        }
        app = _make_app(config, repo, channels_config)
        status = {
            "service_running": True,
            "channels": {
                "feishu": {
                    "enabled": True,
                    "running": False,
                }
            },
        }
        reconciled: list[tuple[str, dict]] = []

        async def ensure_channel_ready(provider, runtime_config):
            # Record the call and flip the reported status to running.
            reconciled.append((provider, dict(runtime_config)))
            status["channels"][provider]["running"] = True
            return True

        service = SimpleNamespace(
            get_status=lambda: status,
            ensure_channel_ready=ensure_channel_ready,
        )
        monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)

        with TestClient(app) as client:
            response = client.get("/api/channels/providers")

        assert response.status_code == 200
        by_provider = {item["provider"]: item for item in response.json()["providers"]}
        assert by_provider["feishu"]["configured"] is True
        assert by_provider["feishu"]["connectable"] is True
        assert by_provider["feishu"]["connection_status"] == "connected"
        assert by_provider["feishu"]["unavailable_reason"] is None
        assert reconciled == [("feishu", channels_config["feishu"])]
    finally:
        anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
||||
import anyio
|
||||
|
||||
|
||||
@@ -3504,6 +3504,51 @@ class TestChannelService:
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
|
||||
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
    """``ChannelService.start()`` starts a channel again when the first start leaves it not running."""
    from app.channels.service import ChannelService

    class FlakyReadyChannel(Channel):
        # Counted on the class so the tally is shared by every instance.
        starts = 0

        def __init__(self, bus, config):
            super().__init__(name="slack", bus=bus, config=config)

        async def start(self):
            # Only the second and later start() calls report the channel as running.
            cls = type(self)
            cls.starts += 1
            self._running = cls.starts >= 2

        async def stop(self):
            self._running = False

        async def send(self, msg):
            return None

    def fake_resolve_class(import_path, base_class=None):
        return FlakyReadyChannel

    monkeypatch.setattr("deerflow.reflection.resolve_class", fake_resolve_class)

    async def scenario():
        slack_settings = {
            "enabled": True,
            "bot_token": "xoxb-ui",
            "app_token": "xapp-ui",
        }
        service = ChannelService(channels_config={"slack": slack_settings})

        try:
            await service.start()

            assert FlakyReadyChannel.starts == 2
            assert service.get_status()["channels"]["slack"]["running"] is True
        finally:
            await service.stop()

    _run(scenario())
|
||||
|
||||
def test_connection_repo_is_forwarded_to_manager(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
|
||||
Reference in New Issue
Block a user