mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Persist disconnected IM channel state
This commit is contained in:
@@ -11,6 +11,8 @@ from typing import Any
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
RUNTIME_CHANNEL_DISABLED_FLAG = "_runtime_disabled"
|
||||||
|
|
||||||
|
|
||||||
class ChannelRuntimeConfigStore:
|
class ChannelRuntimeConfigStore:
|
||||||
"""JSON-backed store for channel credentials entered from the UI.
|
"""JSON-backed store for channel credentials entered from the UI.
|
||||||
@@ -75,6 +77,14 @@ class ChannelRuntimeConfigStore:
|
|||||||
self._data[provider] = dict(config)
|
self._data[provider] = dict(config)
|
||||||
self._save()
|
self._save()
|
||||||
|
|
||||||
|
def set_provider_disconnected(self, provider: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._data[provider] = {
|
||||||
|
"enabled": False,
|
||||||
|
RUNTIME_CHANNEL_DISABLED_FLAG: True,
|
||||||
|
}
|
||||||
|
self._save()
|
||||||
|
|
||||||
def remove_provider_config(self, provider: str) -> bool:
|
def remove_provider_config(self, provider: str) -> bool:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if provider not in self._data:
|
if provider not in self._data:
|
||||||
@@ -89,6 +99,10 @@ def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
|||||||
return bool(getattr(provider_config, "enabled", False))
|
return bool(getattr(provider_config, "enabled", False))
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_channel_disconnected(runtime_config: dict[str, Any]) -> bool:
|
||||||
|
return runtime_config.get(RUNTIME_CHANNEL_DISABLED_FLAG) is True and runtime_config.get("enabled") is False
|
||||||
|
|
||||||
|
|
||||||
def merge_runtime_channel_configs(
|
def merge_runtime_channel_configs(
|
||||||
channels_config: dict[str, Any],
|
channels_config: dict[str, Any],
|
||||||
channel_connections_config: Any,
|
channel_connections_config: Any,
|
||||||
@@ -103,6 +117,9 @@ def merge_runtime_channel_configs(
|
|||||||
for provider, runtime_config in runtime_store.load_all().items():
|
for provider, runtime_config in runtime_store.load_all().items():
|
||||||
if not _provider_enabled(channel_connections_config, provider):
|
if not _provider_enabled(channel_connections_config, provider):
|
||||||
continue
|
continue
|
||||||
|
if _runtime_channel_disconnected(runtime_config):
|
||||||
|
channels_config.pop(provider, None)
|
||||||
|
continue
|
||||||
existing = channels_config.get(provider)
|
existing = channels_config.get(provider)
|
||||||
merged = dict(runtime_config)
|
merged = dict(runtime_config)
|
||||||
if isinstance(existing, dict):
|
if isinstance(existing, dict):
|
||||||
|
|||||||
@@ -547,7 +547,7 @@ async def disconnect_channel_provider_runtime(provider: str, request: Request) -
|
|||||||
owner_user_id=owner_user_id,
|
owner_user_id=owner_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
_get_runtime_config_store(request).remove_provider_config(provider)
|
_get_runtime_config_store(request).set_provider_disconnected(provider)
|
||||||
channels_config = _load_channels_config(request, config)
|
channels_config = _load_channels_config(request, config)
|
||||||
request.app.state.channels_config = channels_config
|
request.app.state.channels_config = channels_config
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -665,7 +666,10 @@ def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
|||||||
assert disconnected["configured"] is False
|
assert disconnected["configured"] is False
|
||||||
assert disconnected["connectable"] is False
|
assert disconnected["connectable"] is False
|
||||||
assert disconnected["connection_status"] == "not_connected"
|
assert disconnected["connection_status"] == "not_connected"
|
||||||
assert runtime_config_store.get_provider_config("slack") is None
|
assert runtime_config_store.get_provider_config("slack") == {
|
||||||
|
"enabled": False,
|
||||||
|
"_runtime_disabled": True,
|
||||||
|
}
|
||||||
|
|
||||||
assert providers_response.status_code == 200
|
assert providers_response.status_code == 200
|
||||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||||
@@ -674,6 +678,79 @@ def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
|||||||
anyio.run(repo.close)
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_provider_runtime_config_suppresses_file_config_and_stops_channel(tmp_path, monkeypatch):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"feishu": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
set_app_config(
|
||||||
|
AppConfig.model_validate(
|
||||||
|
{
|
||||||
|
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||||
|
"channels": {
|
||||||
|
"feishu": {
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "file-app-id",
|
||||||
|
"app_secret": "file-secret",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||||
|
runtime_config_store.set_provider_config(
|
||||||
|
"feishu",
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "runtime-app-id",
|
||||||
|
"app_secret": "runtime-secret",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
service = SimpleNamespace(
|
||||||
|
configure_channel=AsyncMock(return_value=True),
|
||||||
|
remove_channel=AsyncMock(return_value=True),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
|
||||||
|
app = _make_app(
|
||||||
|
config,
|
||||||
|
repo,
|
||||||
|
{
|
||||||
|
"feishu": {
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "runtime-app-id",
|
||||||
|
"app_secret": "runtime-secret",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
runtime_config_store=runtime_config_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
disconnect_response = client.delete("/api/channels/feishu/runtime-config")
|
||||||
|
providers_response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert disconnect_response.status_code == 200
|
||||||
|
disconnected = disconnect_response.json()
|
||||||
|
assert disconnected["provider"] == "feishu"
|
||||||
|
assert disconnected["configured"] is False
|
||||||
|
assert disconnected["connectable"] is False
|
||||||
|
assert disconnected["connection_status"] == "not_connected"
|
||||||
|
assert "feishu" not in app.state.channels_config
|
||||||
|
service.remove_channel.assert_awaited_once_with("feishu")
|
||||||
|
service.configure_channel.assert_not_awaited()
|
||||||
|
|
||||||
|
assert providers_response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||||
|
assert by_provider["feishu"]["configured"] is False
|
||||||
|
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
|
|||||||
@@ -3504,6 +3504,43 @@ class TestChannelService:
|
|||||||
"app_token": "xapp-ui",
|
"app_token": "xapp-ui",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_from_app_config_runtime_disconnect_suppresses_file_channel_config(self, monkeypatch, tmp_path):
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
|
ChannelRuntimeConfigStore().set_provider_config(
|
||||||
|
"feishu",
|
||||||
|
{
|
||||||
|
"enabled": False,
|
||||||
|
"_runtime_disabled": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app_config = SimpleNamespace(
|
||||||
|
model_extra={
|
||||||
|
"channels": {
|
||||||
|
"feishu": {
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "file-app-id",
|
||||||
|
"app_secret": "file-secret",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"feishu": {"enabled": True},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
service = ChannelService.from_app_config(app_config)
|
||||||
|
|
||||||
|
assert "feishu" not in service._config
|
||||||
|
|
||||||
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
|
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
|
||||||
from app.channels.service import ChannelService
|
from app.channels.service import ChannelService
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user