mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Ensure runtime IM channels are ready after restart
This commit is contained in:
@@ -43,6 +43,11 @@ _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
|
|||||||
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
|
||||||
|
|
||||||
|
|
||||||
|
def _channel_has_credentials(name: str, channel_config: dict[str, Any]) -> bool:
    """Report whether ``channel_config`` holds a usable credential for ``name``.

    The keys to inspect are looked up in ``_CHANNEL_CREDENTIAL_KEYS``; a channel
    with no entry there has no keys to check, so the result is ``False``.
    A value counts as a credential when it is neither ``None`` nor a ``bool``
    and its string form is non-blank after stripping whitespace.
    """
    for credential_key in _CHANNEL_CREDENTIAL_KEYS.get(name, []):
        candidate = channel_config.get(credential_key)
        # Missing/None values and booleans are never treated as credentials.
        if candidate is None or isinstance(candidate, bool):
            continue
        if str(candidate).strip():
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, default: str) -> str:
|
||||||
value = config.pop(config_key, None)
|
value = config.pop(config_key, None)
|
||||||
if isinstance(value, str) and value.strip():
|
if isinstance(value, str) and value.strip():
|
||||||
@@ -127,14 +132,20 @@ class ChannelService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
await self.manager.start()
|
await self.manager.start()
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
ready_status = await self.ensure_ready_channels(attempts=2)
|
||||||
|
ready_count = sum(1 for ready in ready_status.values() if ready)
|
||||||
|
logger.info("ChannelService started with %d/%d ready channels", ready_count, len(ready_status))
|
||||||
|
|
||||||
|
async def ensure_ready_channels(self, *, attempts: int = 1) -> dict[str, bool]:
|
||||||
|
"""Start or restart enabled configured channels that are not ready."""
|
||||||
|
ready_status: dict[str, bool] = {}
|
||||||
for name, channel_config in self._config.items():
|
for name, channel_config in self._config.items():
|
||||||
if not isinstance(channel_config, dict):
|
if not isinstance(channel_config, dict):
|
||||||
continue
|
continue
|
||||||
if not channel_config.get("enabled", False):
|
if not channel_config.get("enabled", False):
|
||||||
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
|
if _channel_has_credentials(name, channel_config):
|
||||||
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
|
|
||||||
if has_creds:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
"A configured channel has credentials configured but is disabled. Set enabled: true under its channels entry in config.yaml to activate it.",
|
||||||
)
|
)
|
||||||
@@ -142,10 +153,49 @@ class ChannelService:
|
|||||||
logger.info("A configured channel is disabled, skipping")
|
logger.info("A configured channel is disabled, skipping")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
await self._start_channel(name, channel_config)
|
ready_status[name] = await self.ensure_channel_ready(name, attempts=attempts)
|
||||||
|
return ready_status
|
||||||
|
|
||||||
self._running = True
|
async def ensure_channel_ready(
|
||||||
logger.info("ChannelService started with %d channels", len(self._channels))
|
self,
|
||||||
|
name: str,
|
||||||
|
config: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
attempts: int = 1,
|
||||||
|
) -> bool:
|
||||||
|
"""Ensure a single enabled channel is running using its current config."""
|
||||||
|
if not self._running:
|
||||||
|
logger.warning("ChannelService is not running; cannot ensure channel readiness")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if config is not None:
|
||||||
|
self._config[name] = dict(config)
|
||||||
|
|
||||||
|
channel_config = self._config.get(name)
|
||||||
|
if not channel_config or not isinstance(channel_config, dict):
|
||||||
|
logger.warning("No config for requested channel")
|
||||||
|
return False
|
||||||
|
if not channel_config.get("enabled", False):
|
||||||
|
return False
|
||||||
|
|
||||||
|
channel = self._channels.get(name)
|
||||||
|
if channel is not None and channel.is_running:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if channel is not None:
|
||||||
|
try:
|
||||||
|
await channel.stop()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error stopping non-running channel before readiness retry")
|
||||||
|
self._channels.pop(name, None)
|
||||||
|
|
||||||
|
max_attempts = max(1, attempts)
|
||||||
|
for attempt in range(max_attempts):
|
||||||
|
if attempt > 0:
|
||||||
|
logger.info("Retrying channel startup after readiness check")
|
||||||
|
if await self._start_channel(name, channel_config):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
"""Stop all channels and the manager."""
|
"""Stop all channels and the manager."""
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ def _runtime_unavailable_reason(provider: str) -> str:
|
|||||||
def _runtime_not_running_reason(provider: str) -> str:
    """Return the message used when ``provider`` is configured but not running.

    The provider's ``display_name`` from ``_PROVIDER_META`` is used when an
    entry exists; otherwise the raw ``provider`` key is shown.
    """
    provider_meta = _PROVIDER_META.get(provider)
    if provider_meta:
        display_name = provider_meta["display_name"]
    else:
        display_name = provider
    return f"{display_name} channel is configured but is not running. Check the credentials and service logs."
|
||||||
|
|
||||||
|
|
||||||
def _runtime_channel_running(provider: str) -> bool | None:
|
def _runtime_channel_running(provider: str) -> bool | None:
|
||||||
@@ -244,6 +244,35 @@ def _runtime_channel_running(provider: str) -> bool | None:
|
|||||||
return bool(channel_status.get("running"))
|
return bool(channel_status.get("running"))
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_runtime_channel_ready_if_available(
|
||||||
|
provider: str,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
) -> bool | None:
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to import channel service for readiness reconciliation", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ensure_channel_ready = getattr(service, "ensure_channel_ready", None)
|
||||||
|
if ensure_channel_ready is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await ensure_channel_ready(provider, runtime_config)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to reconcile runtime channel readiness")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _provider_unavailable_reason(
|
def _provider_unavailable_reason(
|
||||||
config: ChannelConnectionsConfig,
|
config: ChannelConnectionsConfig,
|
||||||
channels_config: dict[str, Any],
|
channels_config: dict[str, Any],
|
||||||
@@ -459,6 +488,8 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
|||||||
for provider, meta in _PROVIDER_META.items():
|
for provider, meta in _PROVIDER_META.items():
|
||||||
if not config.provider_status(provider)["enabled"]:
|
if not config.provider_status(provider)["enabled"]:
|
||||||
continue
|
continue
|
||||||
|
if _runtime_channel_configured(provider, channels_config):
|
||||||
|
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||||
connection = by_provider.get(provider)
|
connection = by_provider.get(provider)
|
||||||
providers.append(_provider_response(config, channels_config, provider, meta, connection))
|
providers.append(_provider_response(config, channels_config, provider, meta, connection))
|
||||||
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||||
@@ -535,6 +566,10 @@ async def connect_channel_provider(provider: str, request: Request) -> ChannelCo
|
|||||||
if not config.enabled:
|
if not config.enabled:
|
||||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if provider_config.enabled and _runtime_channel_configured(provider, channels_config):
|
||||||
|
await _ensure_runtime_channel_ready_if_available(provider, channels_config)
|
||||||
|
|
||||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||||
if not status["enabled"]:
|
if not status["enabled"]:
|
||||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
|||||||
@@ -247,6 +247,60 @@ def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypa
|
|||||||
anyio.run(repo.close)
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_restarts_configured_channel_when_service_can_reconcile(tmp_path, monkeypatch):
    """GET /api/channels/providers reconciles a configured-but-stopped channel.

    The stub service reports feishu as not running until its
    ``ensure_channel_ready`` is awaited; the response must then show the
    provider as connected, and the stub must have been called exactly once
    with the feishu runtime config.
    """
    import anyio

    repo = anyio.run(_make_repo, tmp_path)
    config = ChannelConnectionsConfig.model_validate({"enabled": True, "feishu": {"enabled": True}})
    feishu_runtime = {
        "enabled": True,
        "app_id": "feishu-app",
        "app_secret": "feishu-secret",
    }
    channels_config = {"feishu": feishu_runtime}
    app = _make_app(config, repo, channels_config)

    feishu_state = {"enabled": True, "running": False}
    status = {"service_running": True, "channels": {"feishu": feishu_state}}
    reconciled: list[tuple[str, dict]] = []

    async def ensure_channel_ready(provider, runtime_config):
        # Record the call and flip the reported state to running.
        reconciled.append((provider, dict(runtime_config)))
        status["channels"][provider]["running"] = True
        return True

    service = SimpleNamespace(get_status=lambda: status, ensure_channel_ready=ensure_channel_ready)
    monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)

    with TestClient(app) as client:
        response = client.get("/api/channels/providers")

    assert response.status_code == 200
    by_provider = {}
    for item in response.json()["providers"]:
        by_provider[item["provider"]] = item
    feishu = by_provider["feishu"]
    assert feishu["configured"] is True
    assert feishu["connectable"] is True
    assert feishu["connection_status"] == "connected"
    assert feishu["unavailable_reason"] is None
    assert reconciled == [("feishu", channels_config["feishu"])]

    anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
|
|||||||
@@ -3504,6 +3504,51 @@ class TestChannelService:
|
|||||||
"app_token": "xapp-ui",
|
"app_token": "xapp-ui",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
    """``ChannelService.start`` retries a channel that is not running after its first start.

    The fake channel only reports itself as running from its second
    ``start()`` call onward, so the test expects exactly two start calls and a
    running status for "slack" afterwards.
    """
    from app.channels.service import ChannelService

    class FlakyReadyChannel(Channel):
        # Counts start() calls across all instances of this class.
        starts = 0

        def __init__(self, bus, config):
            super().__init__(name="slack", bus=bus, config=config)

        async def start(self):
            channel_cls = type(self)
            channel_cls.starts += 1
            # First start leaves the channel not running; later starts succeed.
            self._running = channel_cls.starts >= 2

        async def stop(self):
            self._running = False

        async def send(self, msg):
            return None

    # Every channel class lookup resolves to the flaky fake.
    monkeypatch.setattr(
        "deerflow.reflection.resolve_class",
        lambda import_path, base_class=None: FlakyReadyChannel,
    )

    slack_config = {
        "enabled": True,
        "bot_token": "xoxb-ui",
        "app_token": "xapp-ui",
    }

    async def go():
        service = ChannelService(channels_config={"slack": slack_config})

        try:
            await service.start()

            assert FlakyReadyChannel.starts == 2
            slack_status = service.get_status()["channels"]["slack"]
            assert slack_status["running"] is True
        finally:
            await service.stop()

    _run(go())
|
||||||
|
|
||||||
def test_connection_repo_is_forwarded_to_manager(self):
|
def test_connection_repo_is_forwarded_to_manager(self):
|
||||||
from app.channels.service import ChannelService
|
from app.channels.service import ChannelService
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user