Allow disconnecting runtime IM channels

This commit is contained in:
taohe
2026-06-11 16:10:02 +08:00
parent ade4a55cfe
commit 4a0278420f
9 changed files with 275 additions and 23 deletions
@@ -75,6 +75,14 @@ class ChannelRuntimeConfigStore:
self._data[provider] = dict(config) self._data[provider] = dict(config)
self._save() self._save()
def remove_provider_config(self, provider: str) -> bool:
with self._lock:
if provider not in self._data:
return False
del self._data[provider]
self._save()
return True
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool: def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
provider_config = getattr(channel_connections_config, provider, None) provider_config = getattr(channel_connections_config, provider, None)
+14
View File
@@ -186,6 +186,20 @@ class ChannelService:
return True return True
return await self.restart_channel(name) return await self.restart_channel(name)
async def remove_channel(self, name: str) -> bool:
"""Remove runtime config for a channel and stop it if currently running."""
self._config.pop(name, None)
channel = self._channels.pop(name, None)
if channel is None:
return True
try:
await channel.stop()
logger.info("Channel %s stopped and removed", name)
return True
except Exception:
logger.exception("Error stopping channel %s for removal", name)
return False
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool: async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
"""Instantiate and start a single channel.""" """Instantiate and start a single channel."""
import_path = _CHANNEL_REGISTRY.get(name) import_path = _CHANNEL_REGISTRY.get(name)
@@ -160,16 +160,21 @@ def _get_channels_config(request: Request) -> dict[str, Any]:
if isinstance(state_config, dict): if isinstance(state_config, dict):
return state_config return state_config
result = _load_channels_config(request, _get_channel_connections_config(request))
request.app.state.channels_config = result
return result
def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
app_config = _get_app_config() app_config = _get_app_config()
extra = app_config.model_extra or {} extra = app_config.model_extra or {}
channels_config = extra.get("channels") channels_config = extra.get("channels")
result = dict(channels_config) if isinstance(channels_config, dict) else {} result = dict(channels_config) if isinstance(channels_config, dict) else {}
merge_runtime_channel_configs( merge_runtime_channel_configs(
result, result,
_get_channel_connections_config(request), config,
store=_get_runtime_config_store(request), store=_get_runtime_config_store(request),
) )
request.app.state.channels_config = result
return result return result
@@ -354,6 +359,23 @@ async def _restart_runtime_channel_if_available(provider: str, runtime_config: d
return await service.configure_channel(provider, runtime_config) return await service.configure_channel(provider, runtime_config)
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
try:
from app.channels.service import get_channel_service
except Exception:
logger.exception("Failed to import channel service while disconnecting %s", provider)
return None
service = get_channel_service()
if service is None:
return None
runtime_config = channels_config.get(provider)
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
return await service.configure_channel(provider, runtime_config)
return await service.remove_channel(provider)
@router.get("/providers", response_model=ChannelProvidersResponse) @router.get("/providers", response_model=ChannelProvidersResponse)
async def get_channel_providers(request: Request) -> ChannelProvidersResponse: async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
config = _get_channel_connections_config(request) config = _get_channel_connections_config(request)
@@ -404,6 +426,44 @@ async def disconnect_channel_connection(connection_id: str, request: Request) ->
return Response(status_code=204) return Response(status_code=204)
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
config = _get_channel_connections_config(request)
if not config.enabled:
raise HTTPException(status_code=400, detail="Channel connections are disabled")
provider_config = _provider_config(config, provider)
if not provider_config.enabled:
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
owner_user_id = _get_user_id(request)
try:
repo = _get_repository(request, config)
except HTTPException as exc:
if exc.status_code != 503:
raise
repo = None
if repo is not None:
for connection in await repo.list_connections(owner_user_id):
if connection["provider"] == provider and connection["status"] != "revoked":
await repo.disconnect_connection(
connection_id=connection["id"],
owner_user_id=owner_user_id,
)
_get_runtime_config_store(request).remove_provider_config(provider)
channels_config = _load_channels_config(request, config)
request.app.state.channels_config = channels_config
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
if stopped is False:
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
@router.post("/{provider}/connect", response_model=ChannelConnectResponse) @router.post("/{provider}/connect", response_model=ChannelConnectResponse)
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse: async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
config = _get_channel_connections_config(request) config = _get_channel_connections_config(request)
@@ -459,6 +459,84 @@ def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
anyio.run(repo.close) anyio.run(repo.close)
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
providers_response = client.get("/api/channels/providers")
assert configure_response.status_code == 200
assert disconnect_response.status_code == 200
disconnected = disconnect_response.json()
assert disconnected["provider"] == "slack"
assert disconnected["configured"] is False
assert disconnected["connectable"] is False
assert disconnected["connection_status"] == "not_connected"
assert runtime_config_store.get_provider_config("slack") is None
assert providers_response.status_code == 200
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
assert by_provider["slack"]["connection_status"] == "not_connected"
anyio.run(repo.close)
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
async def seed_connection():
await repo.upsert_connection(
owner_user_id=str(_user().id),
provider="slack",
external_account_id="U123",
status="connected",
)
anyio.run(seed_connection)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
with TestClient(app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
disconnect_response = client.delete("/api/channels/slack/runtime-config")
assert configure_response.status_code == 200
assert disconnect_response.status_code == 200
async def get_connection_status():
return (await repo.list_connections(str(_user().id)))[0]["status"]
assert anyio.run(get_connection_status) == "revoked"
anyio.run(repo.close)
def test_disconnect_connection_revokes_current_user_connection(tmp_path): def test_disconnect_connection_revokes_current_user_connection(tmp_path):
import anyio import anyio
+25
View File
@@ -3400,6 +3400,31 @@ class TestChannelService:
assert service.manager._connection_repo is repo assert service.manager._connection_repo is repo
def test_remove_channel_stops_running_channel_and_forgets_config(self):
from app.channels.service import ChannelService
async def go():
service = ChannelService(
channels_config={
"slack": {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
}
)
channel = AsyncMock()
service._channels["slack"] = channel
service._running = True
assert await service.remove_channel("slack") is True
channel.stop.assert_awaited_once()
assert "slack" not in service._channels
assert "slack" not in service._config
_run(go())
def test_disabled_channel_with_string_creds_emits_warning(self, caplog): def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
"""Warning is emitted when a channel has string credentials but enabled=false.""" """Warning is emitted when a channel has string credentials but enabled=false."""
import logging import logging
@@ -25,7 +25,7 @@ import {
useChannelConnections, useChannelConnections,
useChannelProviders, useChannelProviders,
useConnectChannelProvider, useConnectChannelProvider,
useDisconnectChannelConnection, useDisconnectChannelProvider,
} from "@/core/channels/hooks"; } from "@/core/channels/hooks";
import { import {
closeConnectWindow, closeConnectWindow,
@@ -122,7 +122,7 @@ function ChannelProviderItem({
const { t } = useI18n(); const { t } = useI18n();
const connectMutation = useConnectChannelProvider(); const connectMutation = useConnectChannelProvider();
const configureMutation = useConfigureChannelProvider(); const configureMutation = useConfigureChannelProvider();
const disconnectMutation = useDisconnectChannelConnection(); const disconnectProviderMutation = useDisconnectChannelProvider();
const [setupOpen, setSetupOpen] = useState(false); const [setupOpen, setSetupOpen] = useState(false);
const isConnected = const isConnected =
connection?.status === "connected" || connection?.status === "connected" ||
@@ -137,8 +137,8 @@ function ChannelProviderItem({
(configureMutation.isPending && (configureMutation.isPending &&
configureMutation.variables?.provider === provider.provider); configureMutation.variables?.provider === provider.provider);
const isDisconnecting = const isDisconnecting =
disconnectMutation.isPending && disconnectProviderMutation.isPending &&
disconnectMutation.variables === connection?.id; disconnectProviderMutation.variables === provider.provider;
const connectionLabel = connection ? getConnectionLabel(connection) : null; const connectionLabel = connection ? getConnectionLabel(connection) : null;
const statusLabel = getStatusLabel(provider, connection, t); const statusLabel = getStatusLabel(provider, connection, t);
const unavailableReason = getProviderUnavailableReason(provider, t); const unavailableReason = getProviderUnavailableReason(provider, t);
@@ -209,7 +209,7 @@ function ChannelProviderItem({
type="button" type="button"
variant="outline" variant="outline"
size="sm" size="sm"
disabled={isConnecting} disabled={isConnecting || isDisconnecting}
onClick={() => setSetupOpen(true)} onClick={() => setSetupOpen(true)}
> >
{isConnecting ? ( {isConnecting ? (
@@ -220,13 +220,25 @@ function ChannelProviderItem({
{t.channels.modify} {t.channels.modify}
</Button> </Button>
) : null} ) : null}
{connection ? (
<Button <Button
type="button" type="button"
variant="outline" variant="outline"
size="sm" size="sm"
disabled={isDisconnecting} disabled={isDisconnecting}
onClick={() => disconnectMutation.mutate(connection.id)} onClick={() => {
void disconnectProviderMutation
.mutateAsync(provider.provider)
.then(() => {
toast.success(t.channels.revoked);
})
.catch((error) => {
toast.error(
error instanceof Error
? error.message
: t.channels.unavailable,
);
});
}}
> >
{isDisconnecting ? ( {isDisconnecting ? (
<LoaderCircleIcon className="animate-spin" /> <LoaderCircleIcon className="animate-spin" />
@@ -235,7 +247,6 @@ function ChannelProviderItem({
)} )}
{t.channels.disconnect} {t.channels.disconnect}
</Button> </Button>
) : null}
</> </>
) : ( ) : (
<Button <Button
+16
View File
@@ -99,3 +99,19 @@ export async function disconnectChannelConnection(
); );
} }
} }
export async function disconnectChannelProvider(
provider: ChannelProviderId,
): Promise<ChannelProvider> {
const response = await fetch(
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
{ method: "DELETE" },
);
if (!response.ok) {
await throwChannelApiError(
response,
`Failed to disconnect ${provider}: ${response.statusText}`,
);
}
return response.json() as Promise<ChannelProvider>;
}
+15
View File
@@ -4,6 +4,7 @@ import {
configureChannelProvider, configureChannelProvider,
connectChannelProvider, connectChannelProvider,
disconnectChannelConnection, disconnectChannelConnection,
disconnectChannelProvider,
listChannelConnections, listChannelConnections,
listChannelProviders, listChannelProviders,
} from "./api"; } from "./api";
@@ -79,3 +80,17 @@ export function useDisconnectChannelConnection() {
}, },
}); });
} }
export function useDisconnectChannelProvider() {
const queryClient = useQueryClient();
return useMutation({
mutationFn: (provider: ChannelProviderId) =>
disconnectChannelProvider(provider),
onSuccess: () => {
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
void queryClient.invalidateQueries({
queryKey: channelConnectionsQueryKey,
});
},
});
}
@@ -13,6 +13,7 @@ import {
configureChannelProvider, configureChannelProvider,
connectChannelProvider, connectChannelProvider,
disconnectChannelConnection, disconnectChannelConnection,
disconnectChannelProvider,
listChannelConnections, listChannelConnections,
listChannelProviders, listChannelProviders,
} from "@/core/channels/api"; } from "@/core/channels/api";
@@ -170,6 +171,30 @@ describe("channels api", () => {
); );
}); });
test("disconnects provider runtime configuration", async () => {
mockedFetch.mockResolvedValueOnce(
jsonResponse(200, {
provider: "slack",
display_name: "Slack",
enabled: true,
configured: false,
connectable: false,
auth_mode: "binding_code",
connection_status: "not_connected",
}),
);
await expect(disconnectChannelProvider("slack")).resolves.toMatchObject({
provider: "slack",
configured: false,
connection_status: "not_connected",
});
expect(mockedFetch).toHaveBeenCalledWith(
"/backend/api/channels/slack/runtime-config",
{ method: "DELETE" },
);
});
test("uses backend detail for failed requests", async () => { test("uses backend detail for failed requests", async () => {
mockedFetch.mockResolvedValueOnce( mockedFetch.mockResolvedValueOnce(
jsonResponse(400, { detail: "Channel provider is not configured" }), jsonResponse(400, { detail: "Channel provider is not configured" }),