mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Persist disconnected IM channel state
This commit is contained in:
@@ -11,6 +11,8 @@ from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
|
||||
|
||||
|
||||
class ChannelRuntimeConfigStore:
|
||||
"""JSON-backed store for channel credentials entered from the UI.
|
||||
@@ -75,6 +77,14 @@ class ChannelRuntimeConfigStore:
|
||||
self._data[provider] = dict(config)
|
||||
self._save()
|
||||
|
||||
def set_provider_disconnected(self, provider: str) -> None:
|
||||
with self._lock:
|
||||
self._data[provider] = {
|
||||
"enabled": False,
|
||||
RUNTIME_CHANNEL_DISABLED_FLAG: True,
|
||||
}
|
||||
self._save()
|
||||
|
||||
def remove_provider_config(self, provider: str) -> bool:
|
||||
with self._lock:
|
||||
if provider not in self._data:
|
||||
@@ -89,6 +99,10 @@ def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||
return bool(getattr(provider_config, "enabled", False))
|
||||
|
||||
|
||||
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
|
||||
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
|
||||
|
||||
|
||||
def merge_runtime_channel_configs(
|
||||
channels_config: dict[str, Any],
|
||||
channel_connections_config: Any,
|
||||
@@ -103,6 +117,9 @@ def merge_runtime_channel_configs(
|
||||
for provider, runtime_config in runtime_store.load_all().items():
|
||||
if not _provider_enabled(channel_connections_config, provider):
|
||||
continue
|
||||
if _runtime_channel_disconnected(runtime_config):
|
||||
channels_config.pop(provider, None)
|
||||
continue
|
||||
existing = channels_config.get(provider)
|
||||
merged = dict(runtime_config)
|
||||
if isinstance(existing, dict):
|
||||
|
||||
@@ -547,7 +547,7 @@ async def disconnect_channel_provider_runtime(provider: str, request: Request) -
|
||||
owner_user_id=owner_user_id,
|
||||
)
|
||||
|
||||
_get_runtime_config_store(request).remove_provider_config(provider)
|
||||
_get_runtime_config_store(request).set_provider_disconnected(provider)
|
||||
channels_config = _load_channels_config(request, config)
|
||||
request.app.state.channels_config = channels_config
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from tempfile import TemporaryDirectory
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
@@ -665,7 +666,10 @@ def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
||||
assert disconnected["configured"] is False
|
||||
assert disconnected["connectable"] is False
|
||||
assert disconnected["connection_status"] == "not_connected"
|
||||
assert runtime_config_store.get_provider_config("slack") is None
|
||||
assert runtime_config_store.get_provider_config("slack") == {
|
||||
"enabled": False,
|
||||
"_runtime_disabled": True,
|
||||
}
|
||||
|
||||
assert providers_response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||
@@ -674,6 +678,79 @@ def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_suppresses_file_config_and_stops_channel(tmp_path, monkeypatch):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"feishu": {"enabled": True},
|
||||
}
|
||||
)
|
||||
set_app_config(
|
||||
AppConfig.model_validate(
|
||||
{
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"channels": {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "file-app-id",
|
||||
"app_secret": "file-secret",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
runtime_config_store.set_provider_config(
|
||||
"feishu",
|
||||
{
|
||||
"enabled": True,
|
||||
"app_id": "runtime-app-id",
|
||||
"app_secret": "runtime-secret",
|
||||
},
|
||||
)
|
||||
service = SimpleNamespace(
|
||||
configure_channel=AsyncMock(return_value=True),
|
||||
remove_channel=AsyncMock(return_value=True),
|
||||
)
|
||||
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
|
||||
app = _make_app(
|
||||
config,
|
||||
repo,
|
||||
{
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "runtime-app-id",
|
||||
"app_secret": "runtime-secret",
|
||||
}
|
||||
},
|
||||
runtime_config_store=runtime_config_store,
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
disconnect_response = client.delete("/api/channels/feishu/runtime-config")
|
||||
providers_response = client.get("/api/channels/providers")
|
||||
|
||||
assert disconnect_response.status_code == 200
|
||||
disconnected = disconnect_response.json()
|
||||
assert disconnected["provider"] == "feishu"
|
||||
assert disconnected["configured"] is False
|
||||
assert disconnected["connectable"] is False
|
||||
assert disconnected["connection_status"] == "not_connected"
|
||||
assert "feishu" not in app.state.channels_config
|
||||
service.remove_channel.assert_awaited_once_with("feishu")
|
||||
service.configure_channel.assert_not_awaited()
|
||||
|
||||
assert providers_response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||
assert by_provider["feishu"]["configured"] is False
|
||||
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
||||
import anyio
|
||||
|
||||
|
||||
@@ -3504,6 +3504,43 @@ class TestChannelService:
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
|
||||
def test_from_app_config_runtime_disconnect_suppresses_file_channel_config(self, monkeypatch, tmp_path):
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(paths_module, "_paths", None)
|
||||
ChannelRuntimeConfigStore().set_provider_config(
|
||||
"feishu",
|
||||
{
|
||||
"enabled": False,
|
||||
"_runtime_disabled": True,
|
||||
},
|
||||
)
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={
|
||||
"channels": {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "file-app-id",
|
||||
"app_secret": "file-secret",
|
||||
}
|
||||
}
|
||||
},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"feishu": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert "feishu" not in service._config
|
||||
|
||||
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
|
||||
Reference in New Issue
Block a user