diff --git a/backend/app/channels/runtime_config_store.py b/backend/app/channels/runtime_config_store.py new file mode 100644 index 000000000..6305806cc --- /dev/null +++ b/backend/app/channels/runtime_config_store.py @@ -0,0 +1,129 @@ +"""Local persistence for runtime IM channel configuration.""" + +from __future__ import annotations + +import json +import logging +import tempfile +import threading +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class ChannelRuntimeConfigStore: + """JSON-backed store for channel credentials entered from the UI. + + This intentionally mirrors ``ChannelStore``: local/private deployments get + durable runtime configuration without needing a public callback URL or a + config.yaml edit. + """ + + def __init__(self, path: str | Path | None = None) -> None: + if path is None: + from deerflow.config.paths import get_paths + + path = Path(get_paths().base_dir) / "channels" / "runtime-config.json" + self._path = Path(path) + self._path.parent.mkdir(parents=True, exist_ok=True) + self._data: dict[str, dict[str, Any]] = self._load() + self._lock = threading.Lock() + + def _load(self) -> dict[str, dict[str, Any]]: + if self._path.exists(): + try: + raw = json.loads(self._path.read_text(encoding="utf-8")) + except (json.JSONDecodeError, OSError): + logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path) + return {} + if isinstance(raw, dict): + return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)} + return {} + + def _save(self) -> None: + fd = tempfile.NamedTemporaryFile( + mode="w", + dir=self._path.parent, + suffix=".tmp", + delete=False, + ) + try: + json.dump(self._data, fd, indent=2, ensure_ascii=False) + fd.close() + Path(fd.name).replace(self._path) + try: + self._path.chmod(0o600) + except OSError: + logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True) + except BaseException: + fd.close() + Path(fd.name).unlink(missing_ok=True) + raise + + def load_all(self) -> dict[str, dict[str, Any]]: + with self._lock: + return {name: dict(config) for name, config in self._data.items()} + + def get_provider_config(self, provider: str) -> dict[str, Any] | None: + with self._lock: + config = self._data.get(provider) + return dict(config) if isinstance(config, dict) else None + + def set_provider_config(self, provider: str, config: dict[str, Any]) -> None: + with self._lock: + self._data[provider] = dict(config) + self._save() + + +def _provider_enabled(channel_connections_config: Any, provider: str) -> bool: + provider_config = getattr(channel_connections_config, provider, None) + return bool(getattr(provider_config, "enabled", False)) + + +def merge_runtime_channel_configs( + channels_config: dict[str, Any], + channel_connections_config: Any, + *, + store: ChannelRuntimeConfigStore | None = None, +) -> None: + """Merge persisted runtime provider config into ``channels_config`` in-place.""" + if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False): + return + + runtime_store = store or ChannelRuntimeConfigStore() + for provider, runtime_config in runtime_store.load_all().items(): + if not _provider_enabled(channel_connections_config, provider): + continue + existing = channels_config.get(provider) + merged = dict(runtime_config) + if isinstance(existing, dict): + merged.update(existing) + channels_config[provider] = merged + + +def apply_runtime_connection_config( + channel_connections_config: Any, + *, + store: ChannelRuntimeConfigStore | None = None, +) -> Any: + """Apply persisted connection metadata that lives outside ``channels``. + + Telegram uses a bot username for deep links; UI-entered values are stored + with the runtime channel config so local restarts keep the provider + configured. + """ + if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False): + return channel_connections_config + + runtime_store = store or ChannelRuntimeConfigStore() + telegram_runtime_config = runtime_store.get_provider_config("telegram") + bot_username = "" + if isinstance(telegram_runtime_config, dict): + bot_username = str(telegram_runtime_config.get("bot_username") or "").strip() + if not bot_username or not _provider_enabled(channel_connections_config, "telegram"): + return channel_connections_config + + config = channel_connections_config.model_copy(deep=True) + config.telegram.bot_username = bot_username + return config diff --git a/backend/app/channels/service.py b/backend/app/channels/service.py index 1261f8096..c742ced5b 100644 --- a/backend/app/channels/service.py +++ b/backend/app/channels/service.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any from app.channels.base import Channel from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager from app.channels.message_bus import MessageBus +from app.channels.runtime_config_store import merge_runtime_channel_configs from app.channels.store import ChannelStore logger = logging.getLogger(__name__) @@ -54,8 +55,7 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str, def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None: connection_config = getattr(app_config, "channel_connections", None) - if connection_config is None or not getattr(connection_config, "enabled", False): - return + merge_runtime_channel_configs(channels_config, connection_config) def _make_connection_repo(app_config: AppConfig): diff --git a/backend/app/gateway/routers/channel_connections.py b/backend/app/gateway/routers/channel_connections.py index c4f45be5e..2ed9aaa1f 100644 --- a/backend/app/gateway/routers/channel_connections.py +++ b/backend/app/gateway/routers/channel_connections.py @@ -10,6 +10,11 @@ from typing import Any from fastapi import APIRouter, HTTPException, Request, Response from pydantic import BaseModel, Field +from app.channels.runtime_config_store import ( + ChannelRuntimeConfigStore, + apply_runtime_connection_config, + merge_runtime_channel_configs, +) from deerflow.config.channel_connections_config import ChannelConnectionsConfig from deerflow.persistence.channel_connections import ChannelConnectionRepository from deerflow.persistence.engine import get_session_factory @@ -132,11 +137,22 @@ def _get_app_config(): return get_app_config() +def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore: + store = getattr(request.app.state, "channel_runtime_config_store", None) + if isinstance(store, ChannelRuntimeConfigStore): + return store + store = ChannelRuntimeConfigStore() + request.app.state.channel_runtime_config_store = store + return store + + def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig: config = getattr(request.app.state, "channel_connections_config", None) - if isinstance(config, ChannelConnectionsConfig): - return config - return _get_app_config().channel_connections + if not isinstance(config, ChannelConnectionsConfig): + config = _get_app_config().channel_connections + config = apply_runtime_connection_config(config, store=_get_runtime_config_store(request)) + request.app.state.channel_connections_config = config + return config def _get_channels_config(request: Request) -> dict[str, Any]: @@ -147,7 +163,14 @@ def _get_channels_config(request: Request) -> dict[str, Any]: app_config = _get_app_config() extra = app_config.model_extra or {} channels_config = extra.get("channels") - return dict(channels_config) if isinstance(channels_config, dict) else {} + result = dict(channels_config) if isinstance(channels_config, dict) else {} + merge_runtime_channel_configs( + result, + _get_channel_connections_config(request), + store=_get_runtime_config_store(request), + ) + request.app.state.channels_config = result + return result def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository: @@ -436,6 +459,7 @@ async def configure_channel_provider_runtime( runtime_config[key] = values[key] if provider == "telegram": + runtime_config["bot_username"] = values["bot_username"] provider_config.bot_username = values["bot_username"] request.app.state.channel_connections_config = config @@ -447,4 +471,6 @@ async def configure_channel_provider_runtime( display_name = _PROVIDER_META[provider]["display_name"] raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.") + _get_runtime_config_store(request).set_provider_config(provider, runtime_config) + return _provider_response(config, channels_config, provider, _PROVIDER_META[provider]) diff --git a/backend/tests/test_channel_connections_router.py b/backend/tests/test_channel_connections_router.py index 108af262e..239169d06 100644 --- a/backend/tests/test_channel_connections_router.py +++ b/backend/tests/test_channel_connections_router.py @@ -7,6 +7,7 @@ from uuid import UUID from _router_auth_helpers import make_authed_test_app from fastapi.testclient import TestClient +from app.channels.runtime_config_store import ChannelRuntimeConfigStore from app.gateway.auth.models import User from app.gateway.routers import channel_connections from deerflow.config.channel_connections_config import ChannelConnectionsConfig @@ -29,11 +30,21 @@ async def _make_repo(tmp_path): return ChannelConnectionRepository(get_session_factory()) -def _make_app(config: ChannelConnectionsConfig, repo, channels_config: dict | None = None): +def _make_app( + config: ChannelConnectionsConfig, + repo, + channels_config: dict | None = None, + *, + runtime_config_store: ChannelRuntimeConfigStore | None = None, + set_channels_config_state: bool = True, +): app = make_authed_test_app(user_factory=_user) app.state.channel_connections_config = config app.state.channel_connection_repo = repo - app.state.channels_config = channels_config or {} + if set_channels_config_state: + app.state.channels_config = channels_config or {} + if runtime_config_store is not None: + app.state.channel_runtime_config_store = runtime_config_store app.include_router(channel_connections.router) return app @@ -398,6 +409,56 @@ def test_configure_provider_runtime_credentials_enables_connect_without_file_edi anyio.run(repo.close) +def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + config = ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ) + runtime_config_path = tmp_path / "channels" / "runtime-config.json" + first_app = _make_app( + config, + repo, + {}, + runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path), + ) + + with TestClient(first_app) as client: + configure_response = client.post( + "/api/channels/slack/runtime-config", + json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}}, + ) + + assert configure_response.status_code == 200 + + restarted_app = _make_app( + config, + repo, + runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path), + set_channels_config_state=False, + ) + + with TestClient(restarted_app) as client: + response = client.get("/api/channels/providers") + + assert response.status_code == 200 + by_provider = {item["provider"]: item for item in response.json()["providers"]} + assert by_provider["slack"]["configured"] is True + assert by_provider["slack"]["connectable"] is True + assert by_provider["slack"]["connection_status"] == "connected" + assert restarted_app.state.channels_config["slack"] == { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + } + + anyio.run(repo.close) + + def test_disconnect_connection_revokes_current_user_connection(tmp_path): import anyio diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 8d00de577..07e3aa8be 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -3287,10 +3287,17 @@ class TestChannelService: assert service._config == {"telegram": {"enabled": False}} - def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(self): + def test_from_app_config_does_not_create_runtime_channels_from_channel_connections( + self, + monkeypatch, + tmp_path, + ): from app.channels.service import ChannelService + from deerflow.config import paths as paths_module from deerflow.config.channel_connections_config import ChannelConnectionsConfig + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) app_config = SimpleNamespace( model_extra={}, channel_connections=ChannelConnectionsConfig.model_validate( @@ -3307,10 +3314,26 @@ class TestChannelService: assert service._config == {} - def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(self): + def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled( + self, + monkeypatch, + tmp_path, + ): + from app.channels.runtime_config_store import ChannelRuntimeConfigStore from app.channels.service import ChannelService + from deerflow.config import paths as paths_module from deerflow.config.channel_connections_config import ChannelConnectionsConfig + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + ChannelRuntimeConfigStore().set_provider_config( + "slack", + { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + }, + ) app_config = SimpleNamespace( model_extra={ "channels": { @@ -3335,6 +3358,40 @@ class TestChannelService: assert service._config["slack"]["app_token"] == "xapp" assert service._config["discord"]["bot_token"] == "discord-bot-token" + def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path): + from app.channels.runtime_config_store import ChannelRuntimeConfigStore + from app.channels.service import ChannelService + from deerflow.config import paths as paths_module + from deerflow.config.channel_connections_config import ChannelConnectionsConfig + + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setattr(paths_module, "_paths", None) + ChannelRuntimeConfigStore().set_provider_config( + "slack", + { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + }, + ) + app_config = SimpleNamespace( + model_extra={}, + channel_connections=ChannelConnectionsConfig.model_validate( + { + "enabled": True, + "slack": {"enabled": True}, + } + ), + ) + + service = ChannelService.from_app_config(app_config) + + assert service._config["slack"] == { + "enabled": True, + "bot_token": "xoxb-ui", + "app_token": "xapp-ui", + } + def test_connection_repo_is_forwarded_to_manager(self): from app.channels.service import ChannelService diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 2ac1a1814..b8f6ceee5 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -1,5 +1,4 @@ import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk"; -import type { ThreadsClient } from "@langchain/langgraph-sdk/client"; import { useStream } from "@langchain/langgraph-sdk/react"; import { type QueryClient, @@ -24,6 +23,11 @@ import type { UploadedFileInfo } from "../uploads"; import { promptInputFilePartToFile, uploadFiles } from "../uploads"; import { fetchThreadTokenUsage } from "./api"; +import { + buildThreadsSearchQueryOptions, + DEFAULT_THREAD_SEARCH_PARAMS, + type ThreadSearchParams, +} from "./thread-search-query"; import { threadTokenUsageQueryKey } from "./token-usage"; import type { AgentThread, @@ -1034,69 +1038,11 @@ export function useThreadHistory( } export function useThreads( - params: Parameters[0] = { - limit: 50, - sortBy: "updated_at", - sortOrder: "desc", - select: ["thread_id", "updated_at", "values", "metadata"], - }, + params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS, ) { const apiClient = getAPIClient(); return useQuery({ - queryKey: ["threads", "search", params], - queryFn: async () => { - const maxResults = params.limit; - const initialOffset = params.offset ?? 0; - const DEFAULT_PAGE_SIZE = 50; - - // Preserve prior semantics: if a non-positive limit is explicitly provided, - // delegate to a single search call with the original parameters. - if (maxResults !== undefined && maxResults <= 0) { - const response = - await apiClient.threads.search(params); - return response as AgentThread[]; - } - - const pageSize = - typeof maxResults === "number" && maxResults > 0 - ? Math.min(DEFAULT_PAGE_SIZE, maxResults) - : DEFAULT_PAGE_SIZE; - - const threads: AgentThread[] = []; - let offset = initialOffset; - - while (true) { - if (typeof maxResults === "number" && threads.length >= maxResults) { - break; - } - - const currentLimit = - typeof maxResults === "number" - ? Math.min(pageSize, maxResults - threads.length) - : pageSize; - - if (typeof maxResults === "number" && currentLimit <= 0) { - break; - } - - const response = (await apiClient.threads.search({ - ...params, - limit: currentLimit, - offset, - })) as AgentThread[]; - - threads.push(...response); - - if (response.length < currentLimit) { - break; - } - - offset += response.length; - } - - return threads; - }, - refetchOnWindowFocus: false, + ...buildThreadsSearchQueryOptions(apiClient, params), }); } diff --git a/frontend/src/core/threads/thread-search-query.ts b/frontend/src/core/threads/thread-search-query.ts new file mode 100644 index 000000000..758df283f --- /dev/null +++ b/frontend/src/core/threads/thread-search-query.ts @@ -0,0 +1,84 @@ +import type { ThreadsClient } from "@langchain/langgraph-sdk/client"; + +import type { AgentThread, AgentThreadState } from "./types"; + +type ThreadsSearchClient = { + threads: { + search: ThreadsClient["search"]; + }; +}; + +export type ThreadSearchParams = NonNullable[0]>; + +export const DEFAULT_THREAD_SEARCH_PARAMS: ThreadSearchParams = { + limit: 50, + sortBy: "updated_at", + sortOrder: "desc", + select: ["thread_id", "updated_at", "values", "metadata"], +}; + +export const THREAD_SEARCH_REFETCH_INTERVAL_MS = 5000; + +export function buildThreadsSearchQueryOptions( + apiClient: ThreadsSearchClient, + params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS, +) { + return { + queryKey: ["threads", "search", params], + queryFn: async () => { + const maxResults = params.limit; + const initialOffset = params.offset ?? 0; + const DEFAULT_PAGE_SIZE = 50; + + // Preserve prior semantics: if a non-positive limit is explicitly provided, + // delegate to a single search call with the original parameters. + if (maxResults !== undefined && maxResults <= 0) { + const response = + await apiClient.threads.search(params); + return response as AgentThread[]; + } + + const pageSize = + typeof maxResults === "number" && maxResults > 0 + ? Math.min(DEFAULT_PAGE_SIZE, maxResults) + : DEFAULT_PAGE_SIZE; + + const threads: AgentThread[] = []; + let offset = initialOffset; + + while (true) { + if (typeof maxResults === "number" && threads.length >= maxResults) { + break; + } + + const currentLimit = + typeof maxResults === "number" + ? Math.min(pageSize, maxResults - threads.length) + : pageSize; + + if (typeof maxResults === "number" && currentLimit <= 0) { + break; + } + + const response = (await apiClient.threads.search({ + ...params, + limit: currentLimit, + offset, + })) as AgentThread[]; + + threads.push(...response); + + if (response.length < currentLimit) { + break; + } + + offset += response.length; + } + + return threads; + }, + refetchInterval: THREAD_SEARCH_REFETCH_INTERVAL_MS, + refetchIntervalInBackground: false, + refetchOnWindowFocus: false, + }; +} diff --git a/frontend/tests/unit/core/threads/thread-search-query.test.ts b/frontend/tests/unit/core/threads/thread-search-query.test.ts new file mode 100644 index 000000000..34a008035 --- /dev/null +++ b/frontend/tests/unit/core/threads/thread-search-query.test.ts @@ -0,0 +1,19 @@ +import { expect, test, vi } from "vitest"; + +import { + buildThreadsSearchQueryOptions, + DEFAULT_THREAD_SEARCH_PARAMS, + THREAD_SEARCH_REFETCH_INTERVAL_MS, +} from "@/core/threads/thread-search-query"; + +test("thread search query refreshes so IM-created sessions appear in the sidebar", () => { + const search = vi.fn(); + const options = buildThreadsSearchQueryOptions( + { threads: { search } }, + DEFAULT_THREAD_SEARCH_PARAMS, + ); + + expect(options.refetchInterval).toBe(THREAD_SEARCH_REFETCH_INTERVAL_MS); + expect(options.refetchIntervalInBackground).toBe(false); + expect(options.refetchOnWindowFocus).toBe(false); +});