diff --git a/backend/app/channels/runtime_config_store.py b/backend/app/channels/runtime_config_store.py index 6305806cc..1f6ed756c 100644 --- a/backend/app/channels/runtime_config_store.py +++ b/backend/app/channels/runtime_config_store.py @@ -75,6 +75,14 @@ class ChannelRuntimeConfigStore: self._data[provider] = dict(config) self._save() + def remove_provider_config(self, provider: str) -> bool: + with self._lock: + if provider not in self._data: + return False + del self._data[provider] + self._save() + return True + def _provider_enabled(channel_connections_config: Any, provider: str) -> bool: provider_config = getattr(channel_connections_config, provider, None) diff --git a/backend/app/channels/service.py b/backend/app/channels/service.py index c742ced5b..8481e4621 100644 --- a/backend/app/channels/service.py +++ b/backend/app/channels/service.py @@ -186,6 +186,20 @@ class ChannelService: return True return await self.restart_channel(name) + async def remove_channel(self, name: str) -> bool: + """Remove runtime config for a channel and stop it if currently running.""" + self._config.pop(name, None) + channel = self._channels.pop(name, None) + if channel is None: + return True + try: + await channel.stop() + logger.info("Channel %s stopped and removed", name) + return True + except Exception: + logger.exception("Error stopping channel %s for removal", name) + return False + async def _start_channel(self, name: str, config: dict[str, Any]) -> bool: """Instantiate and start a single channel.""" import_path = _CHANNEL_REGISTRY.get(name) diff --git a/backend/app/gateway/routers/channel_connections.py b/backend/app/gateway/routers/channel_connections.py index 2ed9aaa1f..e9366af3e 100644 --- a/backend/app/gateway/routers/channel_connections.py +++ b/backend/app/gateway/routers/channel_connections.py @@ -160,16 +160,21 @@ def _get_channels_config(request: Request) -> dict[str, Any]: if isinstance(state_config, dict): return state_config + result = _load_channels_config(request, _get_channel_connections_config(request)) + request.app.state.channels_config = result + return result + + +def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]: app_config = _get_app_config() extra = app_config.model_extra or {} channels_config = extra.get("channels") result = dict(channels_config) if isinstance(channels_config, dict) else {} merge_runtime_channel_configs( result, - _get_channel_connections_config(request), + config, store=_get_runtime_config_store(request), ) - request.app.state.channels_config = result return result @@ -354,6 +359,23 @@ async def _restart_runtime_channel_if_available(provider: str, runtime_config: d return await service.configure_channel(provider, runtime_config) +async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None: + try: + from app.channels.service import get_channel_service + except Exception: + logger.exception("Failed to import channel service while disconnecting %s", provider) + return None + + service = get_channel_service() + if service is None: + return None + + runtime_config = channels_config.get(provider) + if isinstance(runtime_config, dict) and runtime_config.get("enabled", False): + return await service.configure_channel(provider, runtime_config) + return await service.remove_channel(provider) + + @router.get("/providers", response_model=ChannelProvidersResponse) async def get_channel_providers(request: Request) -> ChannelProvidersResponse: config = _get_channel_connections_config(request) @@ -404,6 +426,44 @@ async def disconnect_channel_connection(connection_id: str, request: Request) -> return Response(status_code=204) +@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse) +async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse: + config = _get_channel_connections_config(request) + if not config.enabled: + raise HTTPException(status_code=400, detail="Channel connections are disabled") + + provider_config = _provider_config(config, provider) + if not provider_config.enabled: + raise HTTPException(status_code=400, detail="Channel provider is not enabled") + + owner_user_id = _get_user_id(request) + try: + repo = _get_repository(request, config) + except HTTPException as exc: + if exc.status_code != 503: + raise + repo = None + + if repo is not None: + for connection in await repo.list_connections(owner_user_id): + if connection["provider"] == provider and connection["status"] != "revoked": + await repo.disconnect_connection( + connection_id=connection["id"], + owner_user_id=owner_user_id, + ) + + _get_runtime_config_store(request).remove_provider_config(provider) + channels_config = _load_channels_config(request, config) + request.app.state.channels_config = channels_config + + stopped = await _sync_runtime_channel_after_removal(provider, channels_config) + if stopped is False: + display_name = _PROVIDER_META[provider]["display_name"] + raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.") + + return _provider_response(config, channels_config, provider, _PROVIDER_META[provider]) + + @router.post("/{provider}/connect", response_model=ChannelConnectResponse) async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse: config = _get_channel_connections_config(request) diff --git a/backend/tests/test_channel_connections_router.py b/backend/tests/test_channel_connections_router.py index 239169d06..4944e2420 100644 --- a/backend/tests/test_channel_connections_router.py +++ b/backend/tests/test_channel_connections_router.py @@ -459,6 +459,84 @@ def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path): anyio.run(repo.close) +def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + disconnect_response = client.delete("/api/channels/slack/runtime-config") + providers_response = client.get("/api/channels/providers") + + assert configure_response.status_code == 200 + assert disconnect_response.status_code == 200 + disconnected = disconnect_response.json() + assert disconnected["provider"] == "slack" + assert disconnected["configured"] is False + assert disconnected["connectable"] is False + assert disconnected["connection_status"] == "not_connected" + assert runtime_config_store.get_provider_config("slack") is None + + assert providers_response.status_code == 200 + by_provider = {item["provider"]: item for item in providers_response.json()["providers"]} + assert by_provider["slack"]["connection_status"] == "not_connected" + + anyio.run(repo.close) + + +def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + + async def seed_connection(): + await repo.upsert_connection( + owner_user_id=str(_user().id), + provider="slack", + external_account_id="U123", + status="connected", + ) + + anyio.run(seed_connection) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json") + app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store) + + with TestClient(app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + disconnect_response = client.delete("/api/channels/slack/runtime-config") + + assert configure_response.status_code == 200 + assert disconnect_response.status_code == 200 + + async def get_connection_status(): + return (await repo.list_connections(str(_user().id)))[0]["status"] + + assert anyio.run(get_connection_status) == "revoked" + + anyio.run(repo.close) + + def test_disconnect_connection_revokes_current_user_connection(tmp_path): import anyio diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 07e3aa8be..08d9d4616 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -3400,6 +3400,31 @@ class TestChannelService: assert service.manager._connection_repo is repo + def test_remove_channel_stops_running_channel_and_forgets_config(self): + from app.channels.service import ChannelService + + async def go(): + service = ChannelService( + channels_config={ + "slack": { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + }, + } + ) + channel = AsyncMock() + service._channels["slack"] = channel + service._running = True + + assert await service.remove_channel("slack") is True + + channel.stop.assert_awaited_once() + assert "slack" not in service._channels + assert "slack" not in service._config + + _run(go()) + def test_disabled_channel_with_string_creds_emits_warning(self, caplog): """Warning is emitted when a channel has string credentials but enabled=false.""" import logging diff --git a/frontend/src/components/workspace/settings/channels-settings-page.tsx b/frontend/src/components/workspace/settings/channels-settings-page.tsx index 8934043a7..13d07d138 100644 --- a/frontend/src/components/workspace/settings/channels-settings-page.tsx +++ b/frontend/src/components/workspace/settings/channels-settings-page.tsx @@ -25,7 +25,7 @@ import { useChannelConnections, useChannelProviders, useConnectChannelProvider, - useDisconnectChannelConnection, + useDisconnectChannelProvider, } from "@/core/channels/hooks"; import { closeConnectWindow, @@ -122,7 +122,7 @@ function ChannelProviderItem({ const { t } = useI18n(); const connectMutation = useConnectChannelProvider(); const configureMutation = useConfigureChannelProvider(); - const disconnectMutation = useDisconnectChannelConnection(); + const disconnectProviderMutation = useDisconnectChannelProvider(); const [setupOpen, setSetupOpen] = useState(false); const isConnected = connection?.status === "connected" || @@ -137,8 +137,8 @@ function ChannelProviderItem({ (configureMutation.isPending && configureMutation.variables?.provider === provider.provider); const isDisconnecting = - disconnectMutation.isPending && - disconnectMutation.variables === connection?.id; + disconnectProviderMutation.isPending && + disconnectProviderMutation.variables === provider.provider; const connectionLabel = connection ? getConnectionLabel(connection) : null; const statusLabel = getStatusLabel(provider, connection, t); const unavailableReason = getProviderUnavailableReason(provider, t); @@ -209,7 +209,7 @@ function ChannelProviderItem({ type="button" variant="outline" size="sm" - disabled={isConnecting} + disabled={isConnecting || isDisconnecting} onClick={() => setSetupOpen(true)} > {isConnecting ? ( @@ -220,22 +220,33 @@ function ChannelProviderItem({ {t.channels.modify} ) : null} - {connection ? ( - - ) : null} + ) : (