mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Allow disconnecting runtime IM channels
This commit is contained in:
@@ -75,6 +75,14 @@ class ChannelRuntimeConfigStore:
|
||||
self._data[provider] = dict(config)
|
||||
self._save()
|
||||
|
||||
def remove_provider_config(self, provider: str) -> bool:
|
||||
with self._lock:
|
||||
if provider not in self._data:
|
||||
return False
|
||||
del self._data[provider]
|
||||
self._save()
|
||||
return True
|
||||
|
||||
|
||||
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||
provider_config = getattr(channel_connections_config, provider, None)
|
||||
|
||||
@@ -186,6 +186,20 @@ class ChannelService:
|
||||
return True
|
||||
return await self.restart_channel(name)
|
||||
|
||||
async def remove_channel(self, name: str) -> bool:
|
||||
"""Remove runtime config for a channel and stop it if currently running."""
|
||||
self._config.pop(name, None)
|
||||
channel = self._channels.pop(name, None)
|
||||
if channel is None:
|
||||
return True
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Channel %s stopped and removed", name)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error stopping channel %s for removal", name)
|
||||
return False
|
||||
|
||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||
"""Instantiate and start a single channel."""
|
||||
import_path = _CHANNEL_REGISTRY.get(name)
|
||||
|
||||
@@ -160,16 +160,21 @@ def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||
if isinstance(state_config, dict):
|
||||
return state_config
|
||||
|
||||
result = _load_channels_config(request, _get_channel_connections_config(request))
|
||||
request.app.state.channels_config = result
|
||||
return result
|
||||
|
||||
|
||||
def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
|
||||
app_config = _get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = extra.get("channels")
|
||||
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||
merge_runtime_channel_configs(
|
||||
result,
|
||||
_get_channel_connections_config(request),
|
||||
config,
|
||||
store=_get_runtime_config_store(request),
|
||||
)
|
||||
request.app.state.channels_config = result
|
||||
return result
|
||||
|
||||
|
||||
@@ -354,6 +359,23 @@ async def _restart_runtime_channel_if_available(provider: str, runtime_config: d
|
||||
return await service.configure_channel(provider, runtime_config)
|
||||
|
||||
|
||||
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
|
||||
try:
|
||||
from app.channels.service import get_channel_service
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel service while disconnecting %s", provider)
|
||||
return None
|
||||
|
||||
service = get_channel_service()
|
||||
if service is None:
|
||||
return None
|
||||
|
||||
runtime_config = channels_config.get(provider)
|
||||
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
|
||||
return await service.configure_channel(provider, runtime_config)
|
||||
return await service.remove_channel(provider)
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
@@ -404,6 +426,44 @@ async def disconnect_channel_connection(connection_id: str, request: Request) ->
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
|
||||
owner_user_id = _get_user_id(request)
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
repo = None
|
||||
|
||||
if repo is not None:
|
||||
for connection in await repo.list_connections(owner_user_id):
|
||||
if connection["provider"] == provider and connection["status"] != "revoked":
|
||||
await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id=owner_user_id,
|
||||
)
|
||||
|
||||
_get_runtime_config_store(request).remove_provider_config(provider)
|
||||
channels_config = _load_channels_config(request, config)
|
||||
request.app.state.channels_config = channels_config
|
||||
|
||||
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
|
||||
if stopped is False:
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
|
||||
|
||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||
|
||||
|
||||
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
|
||||
@@ -459,6 +459,84 @@ def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||
providers_response = client.get("/api/channels/providers")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
assert disconnect_response.status_code == 200
|
||||
disconnected = disconnect_response.json()
|
||||
assert disconnected["provider"] == "slack"
|
||||
assert disconnected["configured"] is False
|
||||
assert disconnected["connectable"] is False
|
||||
assert disconnected["connection_status"] == "not_connected"
|
||||
assert runtime_config_store.get_provider_config("slack") is None
|
||||
|
||||
assert providers_response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connection():
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="slack",
|
||||
external_account_id="U123",
|
||||
status="connected",
|
||||
)
|
||||
|
||||
anyio.run(seed_connection)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
assert disconnect_response.status_code == 200
|
||||
|
||||
async def get_connection_status():
|
||||
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||
|
||||
assert anyio.run(get_connection_status) == "revoked"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
||||
import anyio
|
||||
|
||||
|
||||
@@ -3400,6 +3400,31 @@ class TestChannelService:
|
||||
|
||||
assert service.manager._connection_repo is repo
|
||||
|
||||
def test_remove_channel_stops_running_channel_and_forgets_config(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
async def go():
|
||||
service = ChannelService(
|
||||
channels_config={
|
||||
"slack": {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
},
|
||||
}
|
||||
)
|
||||
channel = AsyncMock()
|
||||
service._channels["slack"] = channel
|
||||
service._running = True
|
||||
|
||||
assert await service.remove_channel("slack") is True
|
||||
|
||||
channel.stop.assert_awaited_once()
|
||||
assert "slack" not in service._channels
|
||||
assert "slack" not in service._config
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
||||
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
||||
import logging
|
||||
|
||||
@@ -25,7 +25,7 @@ import {
|
||||
useChannelConnections,
|
||||
useChannelProviders,
|
||||
useConnectChannelProvider,
|
||||
useDisconnectChannelConnection,
|
||||
useDisconnectChannelProvider,
|
||||
} from "@/core/channels/hooks";
|
||||
import {
|
||||
closeConnectWindow,
|
||||
@@ -122,7 +122,7 @@ function ChannelProviderItem({
|
||||
const { t } = useI18n();
|
||||
const connectMutation = useConnectChannelProvider();
|
||||
const configureMutation = useConfigureChannelProvider();
|
||||
const disconnectMutation = useDisconnectChannelConnection();
|
||||
const disconnectProviderMutation = useDisconnectChannelProvider();
|
||||
const [setupOpen, setSetupOpen] = useState(false);
|
||||
const isConnected =
|
||||
connection?.status === "connected" ||
|
||||
@@ -137,8 +137,8 @@ function ChannelProviderItem({
|
||||
(configureMutation.isPending &&
|
||||
configureMutation.variables?.provider === provider.provider);
|
||||
const isDisconnecting =
|
||||
disconnectMutation.isPending &&
|
||||
disconnectMutation.variables === connection?.id;
|
||||
disconnectProviderMutation.isPending &&
|
||||
disconnectProviderMutation.variables === provider.provider;
|
||||
const connectionLabel = connection ? getConnectionLabel(connection) : null;
|
||||
const statusLabel = getStatusLabel(provider, connection, t);
|
||||
const unavailableReason = getProviderUnavailableReason(provider, t);
|
||||
@@ -209,7 +209,7 @@ function ChannelProviderItem({
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={isConnecting}
|
||||
disabled={isConnecting || isDisconnecting}
|
||||
onClick={() => setSetupOpen(true)}
|
||||
>
|
||||
{isConnecting ? (
|
||||
@@ -220,22 +220,33 @@ function ChannelProviderItem({
|
||||
{t.channels.modify}
|
||||
</Button>
|
||||
) : null}
|
||||
{connection ? (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={isDisconnecting}
|
||||
onClick={() => disconnectMutation.mutate(connection.id)}
|
||||
>
|
||||
{isDisconnecting ? (
|
||||
<LoaderCircleIcon className="animate-spin" />
|
||||
) : (
|
||||
<UnplugIcon />
|
||||
)}
|
||||
{t.channels.disconnect}
|
||||
</Button>
|
||||
) : null}
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={isDisconnecting}
|
||||
onClick={() => {
|
||||
void disconnectProviderMutation
|
||||
.mutateAsync(provider.provider)
|
||||
.then(() => {
|
||||
toast.success(t.channels.revoked);
|
||||
})
|
||||
.catch((error) => {
|
||||
toast.error(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: t.channels.unavailable,
|
||||
);
|
||||
});
|
||||
}}
|
||||
>
|
||||
{isDisconnecting ? (
|
||||
<LoaderCircleIcon className="animate-spin" />
|
||||
) : (
|
||||
<UnplugIcon />
|
||||
)}
|
||||
{t.channels.disconnect}
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<Button
|
||||
|
||||
@@ -99,3 +99,19 @@ export async function disconnectChannelConnection(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export async function disconnectChannelProvider(
|
||||
provider: ChannelProviderId,
|
||||
): Promise<ChannelProvider> {
|
||||
const response = await fetch(
|
||||
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
|
||||
{ method: "DELETE" },
|
||||
);
|
||||
if (!response.ok) {
|
||||
await throwChannelApiError(
|
||||
response,
|
||||
`Failed to disconnect ${provider}: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
return response.json() as Promise<ChannelProvider>;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
configureChannelProvider,
|
||||
connectChannelProvider,
|
||||
disconnectChannelConnection,
|
||||
disconnectChannelProvider,
|
||||
listChannelConnections,
|
||||
listChannelProviders,
|
||||
} from "./api";
|
||||
@@ -79,3 +80,17 @@ export function useDisconnectChannelConnection() {
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function useDisconnectChannelProvider() {
|
||||
const queryClient = useQueryClient();
|
||||
return useMutation({
|
||||
mutationFn: (provider: ChannelProviderId) =>
|
||||
disconnectChannelProvider(provider),
|
||||
onSuccess: () => {
|
||||
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: channelConnectionsQueryKey,
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
configureChannelProvider,
|
||||
connectChannelProvider,
|
||||
disconnectChannelConnection,
|
||||
disconnectChannelProvider,
|
||||
listChannelConnections,
|
||||
listChannelProviders,
|
||||
} from "@/core/channels/api";
|
||||
@@ -170,6 +171,30 @@ describe("channels api", () => {
|
||||
);
|
||||
});
|
||||
|
||||
test("disconnects provider runtime configuration", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(200, {
|
||||
provider: "slack",
|
||||
display_name: "Slack",
|
||||
enabled: true,
|
||||
configured: false,
|
||||
connectable: false,
|
||||
auth_mode: "binding_code",
|
||||
connection_status: "not_connected",
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(disconnectChannelProvider("slack")).resolves.toMatchObject({
|
||||
provider: "slack",
|
||||
configured: false,
|
||||
connection_status: "not_connected",
|
||||
});
|
||||
expect(mockedFetch).toHaveBeenCalledWith(
|
||||
"/backend/api/channels/slack/runtime-config",
|
||||
{ method: "DELETE" },
|
||||
);
|
||||
});
|
||||
|
||||
test("uses backend detail for failed requests", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(400, { detail: "Channel provider is not configured" }),
|
||||
|
||||
Reference in New Issue
Block a user