Persist IM runtime config locally

This commit is contained in:
taohe
2026-06-11 15:58:40 +08:00
parent 09872af36c
commit ade4a55cfe
8 changed files with 393 additions and 71 deletions
@@ -0,0 +1,129 @@
"""Local persistence for runtime IM channel configuration."""
from __future__ import annotations
import json
import logging
import tempfile
import threading
from pathlib import Path
from typing import Any
logger = logging.getLogger(__name__)
class ChannelRuntimeConfigStore:
"""JSON-backed store for channel credentials entered from the UI.
This intentionally mirrors ``ChannelStore``: local/private deployments get
durable runtime configuration without needing a public callback URL or a
config.yaml edit.
"""
def __init__(self, path: str | Path | None = None) -> None:
if path is None:
from deerflow.config.paths import get_paths
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
self._path = Path(path)
self._path.parent.mkdir(parents=True, exist_ok=True)
self._data: dict[str, dict[str, Any]] = self._load()
self._lock = threading.Lock()
def _load(self) -> dict[str, dict[str, Any]]:
if self._path.exists():
try:
raw = json.loads(self._path.read_text(encoding="utf-8"))
except (json.JSONDecodeError, OSError):
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
return {}
if isinstance(raw, dict):
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
return {}
def _save(self) -> None:
fd = tempfile.NamedTemporaryFile(
mode="w",
dir=self._path.parent,
suffix=".tmp",
delete=False,
)
try:
json.dump(self._data, fd, indent=2, ensure_ascii=False)
fd.close()
Path(fd.name).replace(self._path)
try:
self._path.chmod(0o600)
except OSError:
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
except BaseException:
fd.close()
Path(fd.name).unlink(missing_ok=True)
raise
def load_all(self) -> dict[str, dict[str, Any]]:
with self._lock:
return {name: dict(config) for name, config in self._data.items()}
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
with self._lock:
config = self._data.get(provider)
return dict(config) if isinstance(config, dict) else None
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
with self._lock:
self._data[provider] = dict(config)
self._save()
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
provider_config = getattr(channel_connections_config, provider, None)
return bool(getattr(provider_config, "enabled", False))
def merge_runtime_channel_configs(
channels_config: dict[str, Any],
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> None:
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return
runtime_store = store or ChannelRuntimeConfigStore()
for provider, runtime_config in runtime_store.load_all().items():
if not _provider_enabled(channel_connections_config, provider):
continue
existing = channels_config.get(provider)
merged = dict(runtime_config)
if isinstance(existing, dict):
merged.update(existing)
channels_config[provider] = merged
def apply_runtime_connection_config(
channel_connections_config: Any,
*,
store: ChannelRuntimeConfigStore | None = None,
) -> Any:
"""Apply persisted connection metadata that lives outside ``channels``.
Telegram uses a bot username for deep links; UI-entered values are stored
with the runtime channel config so local restarts keep the provider
configured.
"""
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
return channel_connections_config
runtime_store = store or ChannelRuntimeConfigStore()
telegram_runtime_config = runtime_store.get_provider_config("telegram")
bot_username = ""
if isinstance(telegram_runtime_config, dict):
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
return channel_connections_config
config = channel_connections_config.model_copy(deep=True)
config.telegram.bot_username = bot_username
return config
+2 -2
View File
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.runtime_config_store import merge_runtime_channel_configs
from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
@@ -54,8 +55,7 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
connection_config = getattr(app_config, "channel_connections", None)
if connection_config is None or not getattr(connection_config, "enabled", False):
return
merge_runtime_channel_configs(channels_config, connection_config)
def _make_connection_repo(app_config: AppConfig):
@@ -10,6 +10,11 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request, Response
from pydantic import BaseModel, Field
from app.channels.runtime_config_store import (
ChannelRuntimeConfigStore,
apply_runtime_connection_config,
merge_runtime_channel_configs,
)
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
from deerflow.persistence.channel_connections import ChannelConnectionRepository
from deerflow.persistence.engine import get_session_factory
@@ -132,11 +137,22 @@ def _get_app_config():
return get_app_config()
def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
store = getattr(request.app.state, "channel_runtime_config_store", None)
if isinstance(store, ChannelRuntimeConfigStore):
return store
store = ChannelRuntimeConfigStore()
request.app.state.channel_runtime_config_store = store
return store
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
config = getattr(request.app.state, "channel_connections_config", None)
if isinstance(config, ChannelConnectionsConfig):
if not isinstance(config, ChannelConnectionsConfig):
config = _get_app_config().channel_connections
config = apply_runtime_connection_config(config, store=_get_runtime_config_store(request))
request.app.state.channel_connections_config = config
return config
return _get_app_config().channel_connections
def _get_channels_config(request: Request) -> dict[str, Any]:
@@ -147,7 +163,14 @@ def _get_channels_config(request: Request) -> dict[str, Any]:
app_config = _get_app_config()
extra = app_config.model_extra or {}
channels_config = extra.get("channels")
return dict(channels_config) if isinstance(channels_config, dict) else {}
result = dict(channels_config) if isinstance(channels_config, dict) else {}
merge_runtime_channel_configs(
result,
_get_channel_connections_config(request),
store=_get_runtime_config_store(request),
)
request.app.state.channels_config = result
return result
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
@@ -436,6 +459,7 @@ async def configure_channel_provider_runtime(
runtime_config[key] = values[key]
if provider == "telegram":
runtime_config["bot_username"] = values["bot_username"]
provider_config.bot_username = values["bot_username"]
request.app.state.channel_connections_config = config
@@ -447,4 +471,6 @@ async def configure_channel_provider_runtime(
display_name = _PROVIDER_META[provider]["display_name"]
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
_get_runtime_config_store(request).set_provider_config(provider, runtime_config)
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
@@ -7,6 +7,7 @@ from uuid import UUID
from _router_auth_helpers import make_authed_test_app
from fastapi.testclient import TestClient
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.gateway.auth.models import User
from app.gateway.routers import channel_connections
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
@@ -29,11 +30,21 @@ async def _make_repo(tmp_path):
return ChannelConnectionRepository(get_session_factory())
def _make_app(config: ChannelConnectionsConfig, repo, channels_config: dict | None = None):
def _make_app(
config: ChannelConnectionsConfig,
repo,
channels_config: dict | None = None,
*,
runtime_config_store: ChannelRuntimeConfigStore | None = None,
set_channels_config_state: bool = True,
):
app = make_authed_test_app(user_factory=_user)
app.state.channel_connections_config = config
app.state.channel_connection_repo = repo
if set_channels_config_state:
app.state.channels_config = channels_config or {}
if runtime_config_store is not None:
app.state.channel_runtime_config_store = runtime_config_store
app.include_router(channel_connections.router)
return app
@@ -398,6 +409,56 @@ def test_configure_provider_runtime_credentials_enables_connect_without_file_edi
anyio.run(repo.close)
def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
config = ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
)
runtime_config_path = tmp_path / "channels" / "runtime-config.json"
first_app = _make_app(
config,
repo,
{},
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
)
with TestClient(first_app) as client:
configure_response = client.post(
"/api/channels/slack/runtime-config",
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
)
assert configure_response.status_code == 200
restarted_app = _make_app(
config,
repo,
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
set_channels_config_state=False,
)
with TestClient(restarted_app) as client:
response = client.get("/api/channels/providers")
assert response.status_code == 200
by_provider = {item["provider"]: item for item in response.json()["providers"]}
assert by_provider["slack"]["configured"] is True
assert by_provider["slack"]["connectable"] is True
assert by_provider["slack"]["connection_status"] == "connected"
assert restarted_app.state.channels_config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
anyio.run(repo.close)
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
import anyio
+59 -2
View File
@@ -3287,10 +3287,17 @@ class TestChannelService:
assert service._config == {"telegram": {"enabled": False}}
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(self):
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(
self,
monkeypatch,
tmp_path,
):
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
app_config = SimpleNamespace(
model_extra={},
channel_connections=ChannelConnectionsConfig.model_validate(
@@ -3307,10 +3314,26 @@ class TestChannelService:
assert service._config == {}
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(self):
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(
self,
monkeypatch,
tmp_path,
):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"slack",
{
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
)
app_config = SimpleNamespace(
model_extra={
"channels": {
@@ -3335,6 +3358,40 @@ class TestChannelService:
assert service._config["slack"]["app_token"] == "xapp"
assert service._config["discord"]["bot_token"] == "discord-bot-token"
def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path):
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
from app.channels.service import ChannelService
from deerflow.config import paths as paths_module
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
monkeypatch.setattr(paths_module, "_paths", None)
ChannelRuntimeConfigStore().set_provider_config(
"slack",
{
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
},
)
app_config = SimpleNamespace(
model_extra={},
channel_connections=ChannelConnectionsConfig.model_validate(
{
"enabled": True,
"slack": {"enabled": True},
}
),
)
service = ChannelService.from_app_config(app_config)
assert service._config["slack"] == {
"enabled": True,
"bot_token": "xoxb-ui",
"app_token": "xapp-ui",
}
def test_connection_repo_is_forwarded_to_manager(self):
from app.channels.service import ChannelService
+7 -61
View File
@@ -1,5 +1,4 @@
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import { useStream } from "@langchain/langgraph-sdk/react";
import {
type QueryClient,
@@ -24,6 +23,11 @@ import type { UploadedFileInfo } from "../uploads";
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
import { fetchThreadTokenUsage } from "./api";
import {
buildThreadsSearchQueryOptions,
DEFAULT_THREAD_SEARCH_PARAMS,
type ThreadSearchParams,
} from "./thread-search-query";
import { threadTokenUsageQueryKey } from "./token-usage";
import type {
AgentThread,
@@ -1034,69 +1038,11 @@ export function useThreadHistory(
}
export function useThreads(
params: Parameters<ThreadsClient["search"]>[0] = {
limit: 50,
sortBy: "updated_at",
sortOrder: "desc",
select: ["thread_id", "updated_at", "values", "metadata"],
},
params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
) {
const apiClient = getAPIClient();
return useQuery<AgentThread[]>({
queryKey: ["threads", "search", params],
queryFn: async () => {
const maxResults = params.limit;
const initialOffset = params.offset ?? 0;
const DEFAULT_PAGE_SIZE = 50;
// Preserve prior semantics: if a non-positive limit is explicitly provided,
// delegate to a single search call with the original parameters.
if (maxResults !== undefined && maxResults <= 0) {
const response =
await apiClient.threads.search<AgentThreadState>(params);
return response as AgentThread[];
}
const pageSize =
typeof maxResults === "number" && maxResults > 0
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
: DEFAULT_PAGE_SIZE;
const threads: AgentThread[] = [];
let offset = initialOffset;
while (true) {
if (typeof maxResults === "number" && threads.length >= maxResults) {
break;
}
const currentLimit =
typeof maxResults === "number"
? Math.min(pageSize, maxResults - threads.length)
: pageSize;
if (typeof maxResults === "number" && currentLimit <= 0) {
break;
}
const response = (await apiClient.threads.search<AgentThreadState>({
...params,
limit: currentLimit,
offset,
})) as AgentThread[];
threads.push(...response);
if (response.length < currentLimit) {
break;
}
offset += response.length;
}
return threads;
},
refetchOnWindowFocus: false,
...buildThreadsSearchQueryOptions(apiClient, params),
});
}
@@ -0,0 +1,84 @@
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
import type { AgentThread, AgentThreadState } from "./types";
type ThreadsSearchClient = {
threads: {
search: ThreadsClient["search"];
};
};
export type ThreadSearchParams = NonNullable<Parameters<ThreadsClient["search"]>[0]>;
export const DEFAULT_THREAD_SEARCH_PARAMS: ThreadSearchParams = {
limit: 50,
sortBy: "updated_at",
sortOrder: "desc",
select: ["thread_id", "updated_at", "values", "metadata"],
};
export const THREAD_SEARCH_REFETCH_INTERVAL_MS = 5000;
export function buildThreadsSearchQueryOptions(
apiClient: ThreadsSearchClient,
params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
) {
return {
queryKey: ["threads", "search", params],
queryFn: async () => {
const maxResults = params.limit;
const initialOffset = params.offset ?? 0;
const DEFAULT_PAGE_SIZE = 50;
// Preserve prior semantics: if a non-positive limit is explicitly provided,
// delegate to a single search call with the original parameters.
if (maxResults !== undefined && maxResults <= 0) {
const response =
await apiClient.threads.search<AgentThreadState>(params);
return response as AgentThread[];
}
const pageSize =
typeof maxResults === "number" && maxResults > 0
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
: DEFAULT_PAGE_SIZE;
const threads: AgentThread[] = [];
let offset = initialOffset;
while (true) {
if (typeof maxResults === "number" && threads.length >= maxResults) {
break;
}
const currentLimit =
typeof maxResults === "number"
? Math.min(pageSize, maxResults - threads.length)
: pageSize;
if (typeof maxResults === "number" && currentLimit <= 0) {
break;
}
const response = (await apiClient.threads.search<AgentThreadState>({
...params,
limit: currentLimit,
offset,
})) as AgentThread[];
threads.push(...response);
if (response.length < currentLimit) {
break;
}
offset += response.length;
}
return threads;
},
refetchInterval: THREAD_SEARCH_REFETCH_INTERVAL_MS,
refetchIntervalInBackground: false,
refetchOnWindowFocus: false,
};
}
@@ -0,0 +1,19 @@
import { expect, test, vi } from "vitest";
import {
buildThreadsSearchQueryOptions,
DEFAULT_THREAD_SEARCH_PARAMS,
THREAD_SEARCH_REFETCH_INTERVAL_MS,
} from "@/core/threads/thread-search-query";
test("thread search query refreshes so IM-created sessions appear in the sidebar", () => {
const search = vi.fn();
const options = buildThreadsSearchQueryOptions(
{ threads: { search } },
DEFAULT_THREAD_SEARCH_PARAMS,
);
expect(options.refetchInterval).toBe(THREAD_SEARCH_REFETCH_INTERVAL_MS);
expect(options.refetchIntervalInBackground).toBe(false);
expect(options.refetchOnWindowFocus).toBe(false);
});