mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Persist IM runtime config locally
This commit is contained in:
@@ -10,6 +10,11 @@ from typing import Any
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.channels.runtime_config_store import (
|
||||
ChannelRuntimeConfigStore,
|
||||
apply_runtime_connection_config,
|
||||
merge_runtime_channel_configs,
|
||||
)
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
@@ -132,11 +137,22 @@ def _get_app_config():
|
||||
return get_app_config()
|
||||
|
||||
|
||||
def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
|
||||
store = getattr(request.app.state, "channel_runtime_config_store", None)
|
||||
if isinstance(store, ChannelRuntimeConfigStore):
|
||||
return store
|
||||
store = ChannelRuntimeConfigStore()
|
||||
request.app.state.channel_runtime_config_store = store
|
||||
return store
|
||||
|
||||
|
||||
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||
config = getattr(request.app.state, "channel_connections_config", None)
|
||||
if isinstance(config, ChannelConnectionsConfig):
|
||||
return config
|
||||
return _get_app_config().channel_connections
|
||||
if not isinstance(config, ChannelConnectionsConfig):
|
||||
config = _get_app_config().channel_connections
|
||||
config = apply_runtime_connection_config(config, store=_get_runtime_config_store(request))
|
||||
request.app.state.channel_connections_config = config
|
||||
return config
|
||||
|
||||
|
||||
def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||
@@ -147,7 +163,14 @@ def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||
app_config = _get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = extra.get("channels")
|
||||
return dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||
merge_runtime_channel_configs(
|
||||
result,
|
||||
_get_channel_connections_config(request),
|
||||
store=_get_runtime_config_store(request),
|
||||
)
|
||||
request.app.state.channels_config = result
|
||||
return result
|
||||
|
||||
|
||||
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||
@@ -436,6 +459,7 @@ async def configure_channel_provider_runtime(
|
||||
runtime_config[key] = values[key]
|
||||
|
||||
if provider == "telegram":
|
||||
runtime_config["bot_username"] = values["bot_username"]
|
||||
provider_config.bot_username = values["bot_username"]
|
||||
request.app.state.channel_connections_config = config
|
||||
|
||||
@@ -447,4 +471,6 @@ async def configure_channel_provider_runtime(
|
||||
display_name = _PROVIDER_META[provider]["display_name"]
|
||||
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
||||
|
||||
_get_runtime_config_store(request).set_provider_config(provider, runtime_config)
|
||||
|
||||
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||
|
||||
Reference in New Issue
Block a user