mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Persist IM runtime config locally
This commit is contained in:
@@ -0,0 +1,129 @@
|
|||||||
|
"""Local persistence for runtime IM channel configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelRuntimeConfigStore:
|
||||||
|
"""JSON-backed store for channel credentials entered from the UI.
|
||||||
|
|
||||||
|
This intentionally mirrors ``ChannelStore``: local/private deployments get
|
||||||
|
durable runtime configuration without needing a public callback URL or a
|
||||||
|
config.yaml edit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: str | Path | None = None) -> None:
|
||||||
|
if path is None:
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
|
||||||
|
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
|
||||||
|
self._path = Path(path)
|
||||||
|
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._data: dict[str, dict[str, Any]] = self._load()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, dict[str, Any]]:
|
||||||
|
if self._path.exists():
|
||||||
|
try:
|
||||||
|
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
|
||||||
|
return {}
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save(self) -> None:
|
||||||
|
fd = tempfile.NamedTemporaryFile(
|
||||||
|
mode="w",
|
||||||
|
dir=self._path.parent,
|
||||||
|
suffix=".tmp",
|
||||||
|
delete=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
json.dump(self._data, fd, indent=2, ensure_ascii=False)
|
||||||
|
fd.close()
|
||||||
|
Path(fd.name).replace(self._path)
|
||||||
|
try:
|
||||||
|
self._path.chmod(0o600)
|
||||||
|
except OSError:
|
||||||
|
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
|
||||||
|
except BaseException:
|
||||||
|
fd.close()
|
||||||
|
Path(fd.name).unlink(missing_ok=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def load_all(self) -> dict[str, dict[str, Any]]:
|
||||||
|
with self._lock:
|
||||||
|
return {name: dict(config) for name, config in self._data.items()}
|
||||||
|
|
||||||
|
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
|
||||||
|
with self._lock:
|
||||||
|
config = self._data.get(provider)
|
||||||
|
return dict(config) if isinstance(config, dict) else None
|
||||||
|
|
||||||
|
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._data[provider] = dict(config)
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||||
|
provider_config = getattr(channel_connections_config, provider, None)
|
||||||
|
return bool(getattr(provider_config, "enabled", False))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_runtime_channel_configs(
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
channel_connections_config: Any,
|
||||||
|
*,
|
||||||
|
store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
|
||||||
|
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
runtime_store = store or ChannelRuntimeConfigStore()
|
||||||
|
for provider, runtime_config in runtime_store.load_all().items():
|
||||||
|
if not _provider_enabled(channel_connections_config, provider):
|
||||||
|
continue
|
||||||
|
existing = channels_config.get(provider)
|
||||||
|
merged = dict(runtime_config)
|
||||||
|
if isinstance(existing, dict):
|
||||||
|
merged.update(existing)
|
||||||
|
channels_config[provider] = merged
|
||||||
|
|
||||||
|
|
||||||
|
def apply_runtime_connection_config(
|
||||||
|
channel_connections_config: Any,
|
||||||
|
*,
|
||||||
|
store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Apply persisted connection metadata that lives outside ``channels``.
|
||||||
|
|
||||||
|
Telegram uses a bot username for deep links; UI-entered values are stored
|
||||||
|
with the runtime channel config so local restarts keep the provider
|
||||||
|
configured.
|
||||||
|
"""
|
||||||
|
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||||
|
return channel_connections_config
|
||||||
|
|
||||||
|
runtime_store = store or ChannelRuntimeConfigStore()
|
||||||
|
telegram_runtime_config = runtime_store.get_provider_config("telegram")
|
||||||
|
bot_username = ""
|
||||||
|
if isinstance(telegram_runtime_config, dict):
|
||||||
|
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
|
||||||
|
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
|
||||||
|
return channel_connections_config
|
||||||
|
|
||||||
|
config = channel_connections_config.model_copy(deep=True)
|
||||||
|
config.telegram.bot_username = bot_username
|
||||||
|
return config
|
||||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||||
from app.channels.message_bus import MessageBus
|
from app.channels.message_bus import MessageBus
|
||||||
|
from app.channels.runtime_config_store import merge_runtime_channel_configs
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -54,8 +55,7 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
|||||||
|
|
||||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||||
connection_config = getattr(app_config, "channel_connections", None)
|
connection_config = getattr(app_config, "channel_connections", None)
|
||||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
merge_runtime_channel_configs(channels_config, connection_config)
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def _make_connection_repo(app_config: AppConfig):
|
def _make_connection_repo(app_config: AppConfig):
|
||||||
|
|||||||
@@ -10,6 +10,11 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, HTTPException, Request, Response
|
from fastapi import APIRouter, HTTPException, Request, Response
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.channels.runtime_config_store import (
|
||||||
|
ChannelRuntimeConfigStore,
|
||||||
|
apply_runtime_connection_config,
|
||||||
|
merge_runtime_channel_configs,
|
||||||
|
)
|
||||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
from deerflow.persistence.engine import get_session_factory
|
from deerflow.persistence.engine import get_session_factory
|
||||||
@@ -132,11 +137,22 @@ def _get_app_config():
|
|||||||
return get_app_config()
|
return get_app_config()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
|
||||||
|
store = getattr(request.app.state, "channel_runtime_config_store", None)
|
||||||
|
if isinstance(store, ChannelRuntimeConfigStore):
|
||||||
|
return store
|
||||||
|
store = ChannelRuntimeConfigStore()
|
||||||
|
request.app.state.channel_runtime_config_store = store
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||||
config = getattr(request.app.state, "channel_connections_config", None)
|
config = getattr(request.app.state, "channel_connections_config", None)
|
||||||
if isinstance(config, ChannelConnectionsConfig):
|
if not isinstance(config, ChannelConnectionsConfig):
|
||||||
return config
|
config = _get_app_config().channel_connections
|
||||||
return _get_app_config().channel_connections
|
config = apply_runtime_connection_config(config, store=_get_runtime_config_store(request))
|
||||||
|
request.app.state.channel_connections_config = config
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def _get_channels_config(request: Request) -> dict[str, Any]:
|
def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||||
@@ -147,7 +163,14 @@ def _get_channels_config(request: Request) -> dict[str, Any]:
|
|||||||
app_config = _get_app_config()
|
app_config = _get_app_config()
|
||||||
extra = app_config.model_extra or {}
|
extra = app_config.model_extra or {}
|
||||||
channels_config = extra.get("channels")
|
channels_config = extra.get("channels")
|
||||||
return dict(channels_config) if isinstance(channels_config, dict) else {}
|
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||||
|
merge_runtime_channel_configs(
|
||||||
|
result,
|
||||||
|
_get_channel_connections_config(request),
|
||||||
|
store=_get_runtime_config_store(request),
|
||||||
|
)
|
||||||
|
request.app.state.channels_config = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||||
@@ -436,6 +459,7 @@ async def configure_channel_provider_runtime(
|
|||||||
runtime_config[key] = values[key]
|
runtime_config[key] = values[key]
|
||||||
|
|
||||||
if provider == "telegram":
|
if provider == "telegram":
|
||||||
|
runtime_config["bot_username"] = values["bot_username"]
|
||||||
provider_config.bot_username = values["bot_username"]
|
provider_config.bot_username = values["bot_username"]
|
||||||
request.app.state.channel_connections_config = config
|
request.app.state.channel_connections_config = config
|
||||||
|
|
||||||
@@ -447,4 +471,6 @@ async def configure_channel_provider_runtime(
|
|||||||
display_name = _PROVIDER_META[provider]["display_name"]
|
display_name = _PROVIDER_META[provider]["display_name"]
|
||||||
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
||||||
|
|
||||||
|
_get_runtime_config_store(request).set_provider_config(provider, runtime_config)
|
||||||
|
|
||||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from uuid import UUID
|
|||||||
from _router_auth_helpers import make_authed_test_app
|
from _router_auth_helpers import make_authed_test_app
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
from app.gateway.auth.models import User
|
from app.gateway.auth.models import User
|
||||||
from app.gateway.routers import channel_connections
|
from app.gateway.routers import channel_connections
|
||||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
@@ -29,11 +30,21 @@ async def _make_repo(tmp_path):
|
|||||||
return ChannelConnectionRepository(get_session_factory())
|
return ChannelConnectionRepository(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
def _make_app(config: ChannelConnectionsConfig, repo, channels_config: dict | None = None):
|
def _make_app(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
repo,
|
||||||
|
channels_config: dict | None = None,
|
||||||
|
*,
|
||||||
|
runtime_config_store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
set_channels_config_state: bool = True,
|
||||||
|
):
|
||||||
app = make_authed_test_app(user_factory=_user)
|
app = make_authed_test_app(user_factory=_user)
|
||||||
app.state.channel_connections_config = config
|
app.state.channel_connections_config = config
|
||||||
app.state.channel_connection_repo = repo
|
app.state.channel_connection_repo = repo
|
||||||
app.state.channels_config = channels_config or {}
|
if set_channels_config_state:
|
||||||
|
app.state.channels_config = channels_config or {}
|
||||||
|
if runtime_config_store is not None:
|
||||||
|
app.state.channel_runtime_config_store = runtime_config_store
|
||||||
app.include_router(channel_connections.router)
|
app.include_router(channel_connections.router)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
@@ -398,6 +409,56 @@ def test_configure_provider_runtime_credentials_enables_connect_without_file_edi
|
|||||||
anyio.run(repo.close)
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_path = tmp_path / "channels" / "runtime-config.json"
|
||||||
|
first_app = _make_app(
|
||||||
|
config,
|
||||||
|
repo,
|
||||||
|
{},
|
||||||
|
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
with TestClient(first_app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
|
||||||
|
restarted_app = _make_app(
|
||||||
|
config,
|
||||||
|
repo,
|
||||||
|
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
|
||||||
|
set_channels_config_state=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with TestClient(restarted_app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||||
|
assert by_provider["slack"]["configured"] is True
|
||||||
|
assert by_provider["slack"]["connectable"] is True
|
||||||
|
assert by_provider["slack"]["connection_status"] == "connected"
|
||||||
|
assert restarted_app.state.channels_config["slack"] == {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
}
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
|
|||||||
@@ -3287,10 +3287,17 @@ class TestChannelService:
|
|||||||
|
|
||||||
assert service._config == {"telegram": {"enabled": False}}
|
assert service._config == {"telegram": {"enabled": False}}
|
||||||
|
|
||||||
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(self):
|
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
from app.channels.service import ChannelService
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
app_config = SimpleNamespace(
|
app_config = SimpleNamespace(
|
||||||
model_extra={},
|
model_extra={},
|
||||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||||
@@ -3307,10 +3314,26 @@ class TestChannelService:
|
|||||||
|
|
||||||
assert service._config == {}
|
assert service._config == {}
|
||||||
|
|
||||||
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(self):
|
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
from app.channels.service import ChannelService
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
|
ChannelRuntimeConfigStore().set_provider_config(
|
||||||
|
"slack",
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
},
|
||||||
|
)
|
||||||
app_config = SimpleNamespace(
|
app_config = SimpleNamespace(
|
||||||
model_extra={
|
model_extra={
|
||||||
"channels": {
|
"channels": {
|
||||||
@@ -3335,6 +3358,40 @@ class TestChannelService:
|
|||||||
assert service._config["slack"]["app_token"] == "xapp"
|
assert service._config["slack"]["app_token"] == "xapp"
|
||||||
assert service._config["discord"]["bot_token"] == "discord-bot-token"
|
assert service._config["discord"]["bot_token"] == "discord-bot-token"
|
||||||
|
|
||||||
|
def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path):
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
|
ChannelRuntimeConfigStore().set_provider_config(
|
||||||
|
"slack",
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app_config = SimpleNamespace(
|
||||||
|
model_extra={},
|
||||||
|
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
service = ChannelService.from_app_config(app_config)
|
||||||
|
|
||||||
|
assert service._config["slack"] == {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
}
|
||||||
|
|
||||||
def test_connection_repo_is_forwarded_to_manager(self):
|
def test_connection_repo_is_forwarded_to_manager(self):
|
||||||
from app.channels.service import ChannelService
|
from app.channels.service import ChannelService
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
|
import type { AIMessage, Message, Run } from "@langchain/langgraph-sdk";
|
||||||
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
|
||||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||||
import {
|
import {
|
||||||
type QueryClient,
|
type QueryClient,
|
||||||
@@ -24,6 +23,11 @@ import type { UploadedFileInfo } from "../uploads";
|
|||||||
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
||||||
|
|
||||||
import { fetchThreadTokenUsage } from "./api";
|
import { fetchThreadTokenUsage } from "./api";
|
||||||
|
import {
|
||||||
|
buildThreadsSearchQueryOptions,
|
||||||
|
DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
|
type ThreadSearchParams,
|
||||||
|
} from "./thread-search-query";
|
||||||
import { threadTokenUsageQueryKey } from "./token-usage";
|
import { threadTokenUsageQueryKey } from "./token-usage";
|
||||||
import type {
|
import type {
|
||||||
AgentThread,
|
AgentThread,
|
||||||
@@ -1034,69 +1038,11 @@ export function useThreadHistory(
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function useThreads(
|
export function useThreads(
|
||||||
params: Parameters<ThreadsClient["search"]>[0] = {
|
params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
limit: 50,
|
|
||||||
sortBy: "updated_at",
|
|
||||||
sortOrder: "desc",
|
|
||||||
select: ["thread_id", "updated_at", "values", "metadata"],
|
|
||||||
},
|
|
||||||
) {
|
) {
|
||||||
const apiClient = getAPIClient();
|
const apiClient = getAPIClient();
|
||||||
return useQuery<AgentThread[]>({
|
return useQuery<AgentThread[]>({
|
||||||
queryKey: ["threads", "search", params],
|
...buildThreadsSearchQueryOptions(apiClient, params),
|
||||||
queryFn: async () => {
|
|
||||||
const maxResults = params.limit;
|
|
||||||
const initialOffset = params.offset ?? 0;
|
|
||||||
const DEFAULT_PAGE_SIZE = 50;
|
|
||||||
|
|
||||||
// Preserve prior semantics: if a non-positive limit is explicitly provided,
|
|
||||||
// delegate to a single search call with the original parameters.
|
|
||||||
if (maxResults !== undefined && maxResults <= 0) {
|
|
||||||
const response =
|
|
||||||
await apiClient.threads.search<AgentThreadState>(params);
|
|
||||||
return response as AgentThread[];
|
|
||||||
}
|
|
||||||
|
|
||||||
const pageSize =
|
|
||||||
typeof maxResults === "number" && maxResults > 0
|
|
||||||
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
|
|
||||||
: DEFAULT_PAGE_SIZE;
|
|
||||||
|
|
||||||
const threads: AgentThread[] = [];
|
|
||||||
let offset = initialOffset;
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (typeof maxResults === "number" && threads.length >= maxResults) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const currentLimit =
|
|
||||||
typeof maxResults === "number"
|
|
||||||
? Math.min(pageSize, maxResults - threads.length)
|
|
||||||
: pageSize;
|
|
||||||
|
|
||||||
if (typeof maxResults === "number" && currentLimit <= 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = (await apiClient.threads.search<AgentThreadState>({
|
|
||||||
...params,
|
|
||||||
limit: currentLimit,
|
|
||||||
offset,
|
|
||||||
})) as AgentThread[];
|
|
||||||
|
|
||||||
threads.push(...response);
|
|
||||||
|
|
||||||
if (response.length < currentLimit) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
offset += response.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
return threads;
|
|
||||||
},
|
|
||||||
refetchOnWindowFocus: false,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
||||||
|
|
||||||
|
import type { AgentThread, AgentThreadState } from "./types";
|
||||||
|
|
||||||
|
type ThreadsSearchClient = {
|
||||||
|
threads: {
|
||||||
|
search: ThreadsClient["search"];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ThreadSearchParams = NonNullable<Parameters<ThreadsClient["search"]>[0]>;
|
||||||
|
|
||||||
|
export const DEFAULT_THREAD_SEARCH_PARAMS: ThreadSearchParams = {
|
||||||
|
limit: 50,
|
||||||
|
sortBy: "updated_at",
|
||||||
|
sortOrder: "desc",
|
||||||
|
select: ["thread_id", "updated_at", "values", "metadata"],
|
||||||
|
};
|
||||||
|
|
||||||
|
export const THREAD_SEARCH_REFETCH_INTERVAL_MS = 5000;
|
||||||
|
|
||||||
|
export function buildThreadsSearchQueryOptions(
|
||||||
|
apiClient: ThreadsSearchClient,
|
||||||
|
params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
|
) {
|
||||||
|
return {
|
||||||
|
queryKey: ["threads", "search", params],
|
||||||
|
queryFn: async () => {
|
||||||
|
const maxResults = params.limit;
|
||||||
|
const initialOffset = params.offset ?? 0;
|
||||||
|
const DEFAULT_PAGE_SIZE = 50;
|
||||||
|
|
||||||
|
// Preserve prior semantics: if a non-positive limit is explicitly provided,
|
||||||
|
// delegate to a single search call with the original parameters.
|
||||||
|
if (maxResults !== undefined && maxResults <= 0) {
|
||||||
|
const response =
|
||||||
|
await apiClient.threads.search<AgentThreadState>(params);
|
||||||
|
return response as AgentThread[];
|
||||||
|
}
|
||||||
|
|
||||||
|
const pageSize =
|
||||||
|
typeof maxResults === "number" && maxResults > 0
|
||||||
|
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
|
||||||
|
: DEFAULT_PAGE_SIZE;
|
||||||
|
|
||||||
|
const threads: AgentThread[] = [];
|
||||||
|
let offset = initialOffset;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (typeof maxResults === "number" && threads.length >= maxResults) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentLimit =
|
||||||
|
typeof maxResults === "number"
|
||||||
|
? Math.min(pageSize, maxResults - threads.length)
|
||||||
|
: pageSize;
|
||||||
|
|
||||||
|
if (typeof maxResults === "number" && currentLimit <= 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = (await apiClient.threads.search<AgentThreadState>({
|
||||||
|
...params,
|
||||||
|
limit: currentLimit,
|
||||||
|
offset,
|
||||||
|
})) as AgentThread[];
|
||||||
|
|
||||||
|
threads.push(...response);
|
||||||
|
|
||||||
|
if (response.length < currentLimit) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
offset += response.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
return threads;
|
||||||
|
},
|
||||||
|
refetchInterval: THREAD_SEARCH_REFETCH_INTERVAL_MS,
|
||||||
|
refetchIntervalInBackground: false,
|
||||||
|
refetchOnWindowFocus: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
import { expect, test, vi } from "vitest";
|
||||||
|
|
||||||
|
import {
|
||||||
|
buildThreadsSearchQueryOptions,
|
||||||
|
DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
|
THREAD_SEARCH_REFETCH_INTERVAL_MS,
|
||||||
|
} from "@/core/threads/thread-search-query";
|
||||||
|
|
||||||
|
test("thread search query refreshes so IM-created sessions appear in the sidebar", () => {
|
||||||
|
const search = vi.fn();
|
||||||
|
const options = buildThreadsSearchQueryOptions(
|
||||||
|
{ threads: { search } },
|
||||||
|
DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(options.refetchInterval).toBe(THREAD_SEARCH_REFETCH_INTERVAL_MS);
|
||||||
|
expect(options.refetchIntervalInBackground).toBe(false);
|
||||||
|
expect(options.refetchOnWindowFocus).toBe(false);
|
||||||
|
});
|
||||||
Reference in New Issue
Block a user