mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Allow disconnecting runtime IM channels
This commit is contained in:
@@ -75,6 +75,14 @@ class ChannelRuntimeConfigStore:
|
|||||||
self._data[provider] = dict(config)
|
self._data[provider] = dict(config)
|
||||||
self._save()
|
self._save()
|
||||||
|
|
||||||
|
def remove_provider_config(self, provider: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
if provider not in self._data:
|
||||||
|
return False
|
||||||
|
del self._data[provider]
|
||||||
|
self._save()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||||
provider_config = getattr(channel_connections_config, provider, None)
|
provider_config = getattr(channel_connections_config, provider, None)
|
||||||
|
|||||||
@@ -186,6 +186,20 @@ class ChannelService:
|
|||||||
return True
|
return True
|
||||||
return await self.restart_channel(name)
|
return await self.restart_channel(name)
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str) -> bool:
|
||||||
|
"""Remove runtime config for a channel and stop it if currently running."""
|
||||||
|
self._config.pop(name, None)
|
||||||
|
channel = self._channels.pop(name, None)
|
||||||
|
if channel is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
await channel.stop()
|
||||||
|
logger.info("Channel %s stopped and removed", name)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error stopping channel %s for removal", name)
|
||||||
|
return False
|
||||||
|
|
||||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||||
"""Instantiate and start a single channel."""
|
"""Instantiate and start a single channel."""
|
||||||
import_path = _CHANNEL_REGISTRY.get(name)
|
import_path = _CHANNEL_REGISTRY.get(name)
|
||||||
|
|||||||
@@ -160,16 +160,21 @@ def _get_channels_config(request: Request) -> dict[str, Any]:
|
|||||||
if isinstance(state_config, dict):
|
if isinstance(state_config, dict):
|
||||||
return state_config
|
return state_config
|
||||||
|
|
||||||
|
result = _load_channels_config(request, _get_channel_connections_config(request))
|
||||||
|
request.app.state.channels_config = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
|
||||||
app_config = _get_app_config()
|
app_config = _get_app_config()
|
||||||
extra = app_config.model_extra or {}
|
extra = app_config.model_extra or {}
|
||||||
channels_config = extra.get("channels")
|
channels_config = extra.get("channels")
|
||||||
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||||
merge_runtime_channel_configs(
|
merge_runtime_channel_configs(
|
||||||
result,
|
result,
|
||||||
_get_channel_connections_config(request),
|
config,
|
||||||
store=_get_runtime_config_store(request),
|
store=_get_runtime_config_store(request),
|
||||||
)
|
)
|
||||||
request.app.state.channels_config = result
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -354,6 +359,23 @@ async def _restart_runtime_channel_if_available(provider: str, runtime_config: d
|
|||||||
return await service.configure_channel(provider, runtime_config)
|
return await service.configure_channel(provider, runtime_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel service while disconnecting %s", provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
|
||||||
|
return await service.configure_channel(provider, runtime_config)
|
||||||
|
return await service.remove_channel(provider)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||||
config = _get_channel_connections_config(request)
|
config = _get_channel_connections_config(request)
|
||||||
@@ -404,6 +426,44 @@ async def disconnect_channel_connection(connection_id: str, request: Request) ->
|
|||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||||
|
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
|
||||||
|
owner_user_id = _get_user_id(request)
|
||||||
|
try:
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code != 503:
|
||||||
|
raise
|
||||||
|
repo = None
|
||||||
|
|
||||||
|
if repo is not None:
|
||||||
|
for connection in await repo.list_connections(owner_user_id):
|
||||||
|
if connection["provider"] == provider and connection["status"] != "revoked":
|
||||||
|
await repo.disconnect_connection(
|
||||||
|
connection_id=connection["id"],
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
_get_runtime_config_store(request).remove_provider_config(provider)
|
||||||
|
channels_config = _load_channels_config(request, config)
|
||||||
|
request.app.state.channels_config = channels_config
|
||||||
|
|
||||||
|
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
|
||||||
|
if stopped is False:
|
||||||
|
display_name = _PROVIDER_META[provider]["display_name"]
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
|
||||||
|
|
||||||
|
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||||
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||||
config = _get_channel_connections_config(request)
|
config = _get_channel_connections_config(request)
|
||||||
|
|||||||
@@ -459,6 +459,84 @@ def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
|
|||||||
anyio.run(repo.close)
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||||
|
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||||
|
providers_response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
assert disconnect_response.status_code == 200
|
||||||
|
disconnected = disconnect_response.json()
|
||||||
|
assert disconnected["provider"] == "slack"
|
||||||
|
assert disconnected["configured"] is False
|
||||||
|
assert disconnected["connectable"] is False
|
||||||
|
assert disconnected["connection_status"] == "not_connected"
|
||||||
|
assert runtime_config_store.get_provider_config("slack") is None
|
||||||
|
|
||||||
|
assert providers_response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||||
|
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connection():
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U123",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
|
||||||
|
anyio.run(seed_connection)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||||
|
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
assert disconnect_response.status_code == 200
|
||||||
|
|
||||||
|
async def get_connection_status():
|
||||||
|
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||||
|
|
||||||
|
assert anyio.run(get_connection_status) == "revoked"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
|
|||||||
@@ -3400,6 +3400,31 @@ class TestChannelService:
|
|||||||
|
|
||||||
assert service.manager._connection_repo is repo
|
assert service.manager._connection_repo is repo
|
||||||
|
|
||||||
|
def test_remove_channel_stops_running_channel_and_forgets_config(self):
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
service = ChannelService(
|
||||||
|
channels_config={
|
||||||
|
"slack": {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
channel = AsyncMock()
|
||||||
|
service._channels["slack"] = channel
|
||||||
|
service._running = True
|
||||||
|
|
||||||
|
assert await service.remove_channel("slack") is True
|
||||||
|
|
||||||
|
channel.stop.assert_awaited_once()
|
||||||
|
assert "slack" not in service._channels
|
||||||
|
assert "slack" not in service._config
|
||||||
|
|
||||||
|
_run(go())
|
||||||
|
|
||||||
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
||||||
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import {
|
|||||||
useChannelConnections,
|
useChannelConnections,
|
||||||
useChannelProviders,
|
useChannelProviders,
|
||||||
useConnectChannelProvider,
|
useConnectChannelProvider,
|
||||||
useDisconnectChannelConnection,
|
useDisconnectChannelProvider,
|
||||||
} from "@/core/channels/hooks";
|
} from "@/core/channels/hooks";
|
||||||
import {
|
import {
|
||||||
closeConnectWindow,
|
closeConnectWindow,
|
||||||
@@ -122,7 +122,7 @@ function ChannelProviderItem({
|
|||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const connectMutation = useConnectChannelProvider();
|
const connectMutation = useConnectChannelProvider();
|
||||||
const configureMutation = useConfigureChannelProvider();
|
const configureMutation = useConfigureChannelProvider();
|
||||||
const disconnectMutation = useDisconnectChannelConnection();
|
const disconnectProviderMutation = useDisconnectChannelProvider();
|
||||||
const [setupOpen, setSetupOpen] = useState(false);
|
const [setupOpen, setSetupOpen] = useState(false);
|
||||||
const isConnected =
|
const isConnected =
|
||||||
connection?.status === "connected" ||
|
connection?.status === "connected" ||
|
||||||
@@ -137,8 +137,8 @@ function ChannelProviderItem({
|
|||||||
(configureMutation.isPending &&
|
(configureMutation.isPending &&
|
||||||
configureMutation.variables?.provider === provider.provider);
|
configureMutation.variables?.provider === provider.provider);
|
||||||
const isDisconnecting =
|
const isDisconnecting =
|
||||||
disconnectMutation.isPending &&
|
disconnectProviderMutation.isPending &&
|
||||||
disconnectMutation.variables === connection?.id;
|
disconnectProviderMutation.variables === provider.provider;
|
||||||
const connectionLabel = connection ? getConnectionLabel(connection) : null;
|
const connectionLabel = connection ? getConnectionLabel(connection) : null;
|
||||||
const statusLabel = getStatusLabel(provider, connection, t);
|
const statusLabel = getStatusLabel(provider, connection, t);
|
||||||
const unavailableReason = getProviderUnavailableReason(provider, t);
|
const unavailableReason = getProviderUnavailableReason(provider, t);
|
||||||
@@ -209,7 +209,7 @@ function ChannelProviderItem({
|
|||||||
type="button"
|
type="button"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
size="sm"
|
size="sm"
|
||||||
disabled={isConnecting}
|
disabled={isConnecting || isDisconnecting}
|
||||||
onClick={() => setSetupOpen(true)}
|
onClick={() => setSetupOpen(true)}
|
||||||
>
|
>
|
||||||
{isConnecting ? (
|
{isConnecting ? (
|
||||||
@@ -220,13 +220,25 @@ function ChannelProviderItem({
|
|||||||
{t.channels.modify}
|
{t.channels.modify}
|
||||||
</Button>
|
</Button>
|
||||||
) : null}
|
) : null}
|
||||||
{connection ? (
|
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
variant="outline"
|
variant="outline"
|
||||||
size="sm"
|
size="sm"
|
||||||
disabled={isDisconnecting}
|
disabled={isDisconnecting}
|
||||||
onClick={() => disconnectMutation.mutate(connection.id)}
|
onClick={() => {
|
||||||
|
void disconnectProviderMutation
|
||||||
|
.mutateAsync(provider.provider)
|
||||||
|
.then(() => {
|
||||||
|
toast.success(t.channels.revoked);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
toast.error(
|
||||||
|
error instanceof Error
|
||||||
|
? error.message
|
||||||
|
: t.channels.unavailable,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
{isDisconnecting ? (
|
{isDisconnecting ? (
|
||||||
<LoaderCircleIcon className="animate-spin" />
|
<LoaderCircleIcon className="animate-spin" />
|
||||||
@@ -235,7 +247,6 @@ function ChannelProviderItem({
|
|||||||
)}
|
)}
|
||||||
{t.channels.disconnect}
|
{t.channels.disconnect}
|
||||||
</Button>
|
</Button>
|
||||||
) : null}
|
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@@ -99,3 +99,19 @@ export async function disconnectChannelConnection(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function disconnectChannelProvider(
|
||||||
|
provider: ChannelProviderId,
|
||||||
|
): Promise<ChannelProvider> {
|
||||||
|
const response = await fetch(
|
||||||
|
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
|
||||||
|
{ method: "DELETE" },
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to disconnect ${provider}: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return response.json() as Promise<ChannelProvider>;
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import {
|
|||||||
configureChannelProvider,
|
configureChannelProvider,
|
||||||
connectChannelProvider,
|
connectChannelProvider,
|
||||||
disconnectChannelConnection,
|
disconnectChannelConnection,
|
||||||
|
disconnectChannelProvider,
|
||||||
listChannelConnections,
|
listChannelConnections,
|
||||||
listChannelProviders,
|
listChannelProviders,
|
||||||
} from "./api";
|
} from "./api";
|
||||||
@@ -79,3 +80,17 @@ export function useDisconnectChannelConnection() {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useDisconnectChannelProvider() {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (provider: ChannelProviderId) =>
|
||||||
|
disconnectChannelProvider(provider),
|
||||||
|
onSuccess: () => {
|
||||||
|
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: channelConnectionsQueryKey,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import {
|
|||||||
configureChannelProvider,
|
configureChannelProvider,
|
||||||
connectChannelProvider,
|
connectChannelProvider,
|
||||||
disconnectChannelConnection,
|
disconnectChannelConnection,
|
||||||
|
disconnectChannelProvider,
|
||||||
listChannelConnections,
|
listChannelConnections,
|
||||||
listChannelProviders,
|
listChannelProviders,
|
||||||
} from "@/core/channels/api";
|
} from "@/core/channels/api";
|
||||||
@@ -170,6 +171,30 @@ describe("channels api", () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test("disconnects provider runtime configuration", async () => {
|
||||||
|
mockedFetch.mockResolvedValueOnce(
|
||||||
|
jsonResponse(200, {
|
||||||
|
provider: "slack",
|
||||||
|
display_name: "Slack",
|
||||||
|
enabled: true,
|
||||||
|
configured: false,
|
||||||
|
connectable: false,
|
||||||
|
auth_mode: "binding_code",
|
||||||
|
connection_status: "not_connected",
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
await expect(disconnectChannelProvider("slack")).resolves.toMatchObject({
|
||||||
|
provider: "slack",
|
||||||
|
configured: false,
|
||||||
|
connection_status: "not_connected",
|
||||||
|
});
|
||||||
|
expect(mockedFetch).toHaveBeenCalledWith(
|
||||||
|
"/backend/api/channels/slack/runtime-config",
|
||||||
|
{ method: "DELETE" },
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
test("uses backend detail for failed requests", async () => {
|
test("uses backend detail for failed requests", async () => {
|
||||||
mockedFetch.mockResolvedValueOnce(
|
mockedFetch.mockResolvedValueOnce(
|
||||||
jsonResponse(400, { detail: "Channel provider is not configured" }),
|
jsonResponse(400, { detail: "Channel provider is not configured" }),
|
||||||
|
|||||||
Reference in New Issue
Block a user