Merge refactor/config-deerflow-context into release/2.0-rc

Cherry-pick PR #2271's config refactor onto release/2.0-rc.
Used 'git merge -X theirs' to auto-resolve content conflicts in favor of
the PR's design (frozen AppConfig + explicit-parameter passing).

Limitations:
- Release-only changes that overlapped with PR's refactor in 119 files
  are NOT preserved — those files reflect PR's version. Follow-up commits
  on this branch will need to re-apply release-only modifications where
  meaningful.
- See PR #2271 for design rationale.
This commit is contained in:
greatmengqi
2026-04-27 18:16:42 +08:00
227 changed files with 6965 additions and 5578 deletions
+3 -1
View File
@@ -375,7 +375,9 @@ class FeishuChannel(Channel):
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
try:
sandbox_provider = get_sandbox_provider()
from deerflow.config.app_config import AppConfig
sandbox_provider = get_sandbox_provider(AppConfig.from_file())
sandbox_id = sandbox_provider.acquire(thread_id)
if sandbox_id != "local":
sandbox = sandbox_provider.get(sandbox_id)
-2
View File
@@ -17,8 +17,6 @@ from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
+9 -9
View File
@@ -4,13 +4,16 @@ from __future__ import annotations
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus
from app.channels.store import ChannelStore
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
# Channel name → import path for lazy loading
@@ -75,14 +78,11 @@ class ChannelService:
self._running = False
@classmethod
def from_app_config(cls) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
config = get_app_config()
def from_app_config(cls, app_config: AppConfig) -> ChannelService:
"""Create a ChannelService from an explicit application config."""
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
@@ -201,12 +201,12 @@ def get_channel_service() -> ChannelService | None:
return _channel_service
async def start_channel_service() -> ChannelService:
async def start_channel_service(app_config: AppConfig) -> ChannelService:
"""Create and start the global ChannelService from app config."""
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config()
_channel_service = ChannelService.from_app_config(app_config)
await _channel_service.start()
return _channel_service
+11 -16
View File
@@ -28,7 +28,7 @@ from app.gateway.routers import (
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
# Configure logging
logging.basicConfig(
@@ -72,18 +72,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
from deerflow.persistence.engine import get_session_factory
from deerflow.persistence.user.model import UserRow
try:
provider = get_local_provider()
except RuntimeError:
# Auth persistence may not be initialized in some test/boot paths.
# Skip admin migration work rather than failing gateway startup.
logger.warning("Auth persistence not ready; skipping admin bootstrap check")
return
sf = get_session_factory()
if sf is None:
return
provider = get_local_provider()
admin_count = await provider.count_admin_users()
if admin_count == 0:
@@ -95,6 +84,10 @@ async def _ensure_admin_user(app: FastAPI) -> None:
# Admin already exists — run orphan thread migration for any
# LangGraph thread metadata that pre-dates the auth module.
sf = get_session_factory()
if sf is None:
return
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
@@ -158,9 +151,11 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
# Load config and check necessary environment variables at startup
try:
get_app_config()
# ``app.state.config`` is the sole source of truth for
# ``Depends(get_config)``. Consumers that want AppConfig must receive
# it as an explicit parameter; there is no ambient singleton.
app.state.config = AppConfig.from_file()
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -181,7 +176,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service()
channel_service = await start_channel_service(app.state.config)
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
+2 -2
View File
@@ -12,12 +12,12 @@ class AuthProvider(ABC):
Returns User if authentication succeeds, None otherwise.
"""
raise NotImplementedError
...
@abstractmethod
async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID."""
raise NotImplementedError
...
# Import User at runtime to avoid circular imports
@@ -35,7 +35,7 @@ class UserRepository(ABC):
Raises:
ValueError: If email already exists
"""
raise NotImplementedError
...
@abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None:
@@ -47,7 +47,7 @@ class UserRepository(ABC):
Returns:
User if found, None otherwise
"""
raise NotImplementedError
...
@abstractmethod
async def get_user_by_email(self, email: str) -> User | None:
@@ -59,7 +59,7 @@ class UserRepository(ABC):
Returns:
User if found, None otherwise
"""
raise NotImplementedError
...
@abstractmethod
async def update_user(self, user: User) -> User:
@@ -76,17 +76,17 @@ class UserRepository(ABC):
a hard failure (not a no-op) so callers cannot mistake a
concurrent-delete race for a successful update.
"""
raise NotImplementedError
...
@abstractmethod
async def count_users(self) -> int:
"""Return total number of registered users."""
raise NotImplementedError
...
@abstractmethod
async def count_admin_users(self) -> int:
"""Return number of users with system_role == 'admin'."""
raise NotImplementedError
...
@abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
@@ -99,4 +99,4 @@ class UserRepository(ABC):
Returns:
User if found, None otherwise
"""
raise NotImplementedError
...
+3 -2
View File
@@ -25,14 +25,15 @@ from deerflow.persistence.user.model import UserRow
async def _run(email: str | None) -> int:
from deerflow.config import get_app_config
from deerflow.config import AppConfig
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine_from_config,
)
config = get_app_config()
# CLI entry: load config explicitly at the top, pass down through the closure.
config = AppConfig.from_file()
await init_engine_from_config(config.database)
try:
sf = get_session_factory()
+5 -13
View File
@@ -18,7 +18,6 @@ from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
@@ -76,12 +75,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
if _is_public(request.url.path):
return await call_next(request)
internal_user = None
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user()
# Non-public path: require session cookie
if internal_user is None and not request.cookies.get("access_token"):
if not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
@@ -105,13 +100,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
if internal_user is not None:
user = internal_user
else:
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
+2 -33
View File
@@ -30,9 +30,7 @@ Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blo
from __future__ import annotations
import functools
import inspect
from collections.abc import Callable
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
@@ -119,15 +117,6 @@ _ALL_PERMISSIONS: list[str] = [
]
def _make_test_request_stub() -> Any:
"""Create a minimal request-like object for direct unit calls.
Used when decorated route handlers are invoked without FastAPI's
request injection. Includes fields accessed by auth helpers.
"""
return SimpleNamespace(state=SimpleNamespace(), cookies={}, _deerflow_test_bypass_auth=True)
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
@@ -165,17 +154,7 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
# Unit tests may call decorated handlers directly without a
# FastAPI Request object. Inject a minimal request stub when
# the wrapped function declares `request`.
if "request" in inspect.signature(func).parameters:
kwargs["request"] = _make_test_request_stub()
else:
raise ValueError("require_auth decorator requires 'request' parameter")
request = kwargs["request"]
if getattr(request, "_deerflow_test_bypass_auth", False):
return await func(*args, **kwargs)
raise ValueError("require_auth decorator requires 'request' parameter")
# Authenticate and set context
auth_context = await _authenticate(request)
@@ -231,17 +210,7 @@ def require_permission(
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
# Unit tests may call decorated route handlers directly without
# constructing a FastAPI Request object. Inject a minimal stub
# when the wrapped function declares `request`.
if "request" in inspect.signature(func).parameters:
kwargs["request"] = _make_test_request_stub()
else:
return await func(*args, **kwargs)
request = kwargs["request"]
if getattr(request, "_deerflow_test_bypass_auth", False):
return await func(*args, **kwargs)
raise ValueError("require_permission decorator requires 'request' parameter")
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
+38 -24
View File
@@ -10,15 +10,13 @@ from __future__ import annotations
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, TypeVar, cast
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.runtime import RunContext, RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore
from deerflow.config.app_config import AppConfig
from deerflow.runtime import RunContext, RunManager
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
@@ -26,7 +24,17 @@ if TYPE_CHECKING:
from deerflow.persistence.thread_meta.base import ThreadMetaStore
T = TypeVar("T")
def get_config(request: Request) -> AppConfig:
"""FastAPI dependency returning the app-scoped ``AppConfig``.
Reads from ``request.app.state.config`` which is set at startup
(``app.py`` lifespan) and swapped on config reload (``routers/mcp.py``,
``routers/skills.py``).
"""
cfg = getattr(request.app.state, "config", None)
if cfg is None:
raise HTTPException(status_code=503, detail="Configuration not available")
return cfg
@asynccontextmanager
@@ -38,22 +46,24 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
yield
"""
from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
# app.state.config is populated earlier in lifespan(); thread it
# explicitly into every provider below.
config = app.state.config
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
config = get_app_config()
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
app.state.store = await stack.enter_async_context(make_store(config))
# Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory()
@@ -91,25 +101,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# ---------------------------------------------------------------------------
def _require(attr: str, label: str) -> Callable[[Request], T]:
def _require(attr: str, label: str):
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
def dep(request: Request) -> T:
def dep(request: Request):
val = getattr(request.app.state, attr, None)
if val is None:
raise HTTPException(status_code=503, detail=f"{label} not available")
return cast(T, val)
return val
dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
get_stream_bridge = _require("stream_bridge", "Stream bridge")
get_run_manager = _require("run_manager", "Run manager")
get_checkpointer = _require("checkpointer", "Checkpointer")
get_run_event_store = _require("run_event_store", "Run event store")
get_feedback_repo = _require("feedback_repo", "Feedback")
get_run_store = _require("run_store", "Run store")
def get_store(request: Request):
@@ -128,19 +138,23 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies.
Returns a *base* context with infrastructure dependencies. Callers that
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
"""
from deerflow.config import get_app_config
config = get_config(request)
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None),
run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request),
app_config=config,
)
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------
+20 -19
View File
@@ -5,11 +5,12 @@ import re
import shutil
import yaml
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config
from app.gateway.deps import get_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
@@ -77,9 +78,9 @@ def _normalize_agent_name(name: str) -> str:
return name.lower()
def _require_agents_api_enabled() -> None:
def _require_agents_api_enabled(app_config: AppConfig) -> None:
"""Reject access unless the custom-agent management API is explicitly enabled."""
if not get_agents_api_config().enabled:
if not app_config.agents_api.enabled:
raise HTTPException(
status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
@@ -108,13 +109,13 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
summary="List Custom Agents",
description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
async def list_agents(app_config: AppConfig = Depends(get_config)) -> AgentsListResponse:
"""List all custom agents.
Returns:
List of all custom agents with their metadata and soul content.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
try:
agents = list_custom_agents()
@@ -141,7 +142,7 @@ async def check_agent_name(name: str) -> dict:
Raises:
HTTPException: 422 if the name is invalid.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists()
@@ -154,7 +155,7 @@ async def check_agent_name(name: str) -> dict:
summary="Get Custom Agent",
description="Retrieve details and SOUL.md content for a specific custom agent.",
)
async def get_agent(name: str) -> AgentResponse:
async def get_agent(name: str, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Get a specific custom agent by name.
Args:
@@ -166,7 +167,7 @@ async def get_agent(name: str) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -187,7 +188,7 @@ async def get_agent(name: str) -> AgentResponse:
summary="Create Custom Agent",
description="Create a new custom agent with its config and SOUL.md.",
)
async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
async def create_agent_endpoint(request: AgentCreateRequest, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Create a new custom agent.
Args:
@@ -199,7 +200,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
_validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name)
@@ -251,7 +252,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
summary="Update Custom Agent",
description="Update an existing custom agent's config and/or SOUL.md.",
)
async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
async def update_agent(name: str, request: AgentUpdateRequest, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Update an existing custom agent.
Args:
@@ -264,7 +265,7 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
name = _normalize_agent_name(name)
@@ -342,13 +343,13 @@ class UserProfileUpdateRequest(BaseModel):
summary="Get User Profile",
description="Read the global USER.md file that is injected into all custom agents.",
)
async def get_user_profile() -> UserProfileResponse:
async def get_user_profile(app_config: AppConfig = Depends(get_config)) -> UserProfileResponse:
"""Return the current USER.md content.
Returns:
UserProfileResponse with content=None if USER.md does not exist yet.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
try:
user_md_path = get_paths().user_md_file
@@ -367,7 +368,7 @@ async def get_user_profile() -> UserProfileResponse:
summary="Update User Profile",
description="Write the global USER.md file that is injected into all custom agents.",
)
async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileResponse:
async def update_user_profile(request: UserProfileUpdateRequest, app_config: AppConfig = Depends(get_config)) -> UserProfileResponse:
"""Create or overwrite the global USER.md.
Args:
@@ -376,7 +377,7 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns:
UserProfileResponse with the saved content.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
try:
paths = get_paths()
@@ -395,7 +396,7 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
summary="Delete Custom Agent",
description="Delete a custom agent and all its files (config, SOUL.md, memory).",
)
async def delete_agent(name: str) -> None:
async def delete_agent(name: str, app_config: AppConfig = Depends(get_config)) -> None:
"""Delete a custom agent.
Args:
@@ -404,7 +405,7 @@ async def delete_agent(name: str) -> None:
Raises:
HTTPException: 404 if agent not found.
"""
_require_agents_api_enabled()
_require_agents_api_enabled(app_config)
_validate_agent_name(name)
name = _normalize_agent_name(name)
+20 -12
View File
@@ -3,10 +3,12 @@ import logging
from pathlib import Path
from typing import Literal
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"])
@@ -69,7 +71,7 @@ class McpConfigUpdateRequest(BaseModel):
summary="Get MCP Configuration",
description="Retrieve the current Model Context Protocol (MCP) server configurations.",
)
async def get_mcp_configuration() -> McpConfigResponse:
async def get_mcp_configuration(config: AppConfig = Depends(get_config)) -> McpConfigResponse:
"""Get the current MCP configuration.
Returns:
@@ -90,9 +92,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
}
```
"""
config = get_extensions_config()
ext = config.extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()})
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
@router.put(
@@ -101,7 +103,11 @@ async def get_mcp_configuration() -> McpConfigResponse:
summary="Update MCP Configuration",
description="Update Model Context Protocol (MCP) server configurations and save to file.",
)
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse:
async def update_mcp_configuration(
request: McpConfigUpdateRequest,
http_request: Request,
config: AppConfig = Depends(get_config),
) -> McpConfigResponse:
"""Update the MCP configuration.
This will:
@@ -142,13 +148,13 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration
current_config = get_extensions_config()
# Use injected config to preserve skills configuration
current_ext = config.extensions
# Convert request to dict format for JSON serialization
config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
# Write the configuration to file
@@ -160,9 +166,11 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache
reloaded_config = reload_extensions_config()
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()})
# Reload the configuration and swap ``app.state.config`` so subsequent
# ``Depends(get_config)`` calls see the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded.extensions.mcp_servers.items()})
except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
+28 -21
View File
@@ -1,8 +1,9 @@
"""Memory API router for retrieving and managing global memory data."""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from deerflow.agents.memory.updater import (
clear_memory_data,
create_memory_fact,
@@ -12,7 +13,7 @@ from deerflow.agents.memory.updater import (
reload_memory_data,
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.user_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"])
@@ -114,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.",
)
async def get_memory() -> MemoryResponse:
async def get_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Get the current global memory data.
Returns:
@@ -148,7 +149,7 @@ async def get_memory() -> MemoryResponse:
}
```
"""
memory_data = get_memory_data(user_id=get_effective_user_id())
memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@@ -159,7 +160,7 @@ async def get_memory() -> MemoryResponse:
summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.",
)
async def reload_memory() -> MemoryResponse:
async def reload_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Reload memory data from file.
This forces a reload of the memory data from the storage file,
@@ -168,7 +169,7 @@ async def reload_memory() -> MemoryResponse:
Returns:
The reloaded memory data.
"""
memory_data = reload_memory_data(user_id=get_effective_user_id())
memory_data = reload_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@@ -179,10 +180,10 @@ async def reload_memory() -> MemoryResponse:
summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.",
)
async def clear_memory() -> MemoryResponse:
async def clear_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Clear all persisted memory data."""
try:
memory_data = clear_memory_data(user_id=get_effective_user_id())
memory_data = clear_memory_data(app_config.memory, user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
@@ -196,10 +197,11 @@ async def clear_memory() -> MemoryResponse:
summary="Create Memory Fact",
description="Create a single saved memory fact manually.",
)
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
async def create_memory_fact_endpoint(request: FactCreateRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Create a single fact manually."""
try:
memory_data = create_memory_fact(
app_config.memory,
content=request.content,
category=request.category,
confidence=request.confidence,
@@ -220,10 +222,10 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.",
)
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
async def delete_memory_fact_endpoint(fact_id: str, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
memory_data = delete_memory_fact(app_config.memory, fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
@@ -239,10 +241,11 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
)
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Partially update a single fact manually."""
try:
memory_data = update_memory_fact(
app_config.memory,
fact_id=fact_id,
content=request.content,
category=request.category,
@@ -266,9 +269,9 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.",
)
async def export_memory() -> MemoryResponse:
async def export_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Export the current memory data."""
memory_data = get_memory_data(user_id=get_effective_user_id())
memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@@ -279,10 +282,10 @@ async def export_memory() -> MemoryResponse:
summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.",
)
async def import_memory(request: MemoryResponse) -> MemoryResponse:
async def import_memory(request: MemoryResponse, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Import and persist memory data."""
try:
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
memory_data = import_memory_data(app_config.memory, request.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
@@ -295,7 +298,9 @@ async def import_memory(request: MemoryResponse) -> MemoryResponse:
summary="Get Memory Configuration",
description="Retrieve the current memory system configuration.",
)
async def get_memory_config_endpoint() -> MemoryConfigResponse:
async def get_memory_config_endpoint(
app_config: AppConfig = Depends(get_config),
) -> MemoryConfigResponse:
"""Get the memory system configuration.
Returns:
@@ -314,7 +319,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
}
```
"""
config = get_memory_config()
config = app_config.memory
return MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
@@ -333,14 +338,16 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.",
)
async def get_memory_status() -> MemoryStatusResponse:
async def get_memory_status(
app_config: AppConfig = Depends(get_config),
) -> MemoryStatusResponse:
"""Get the memory system status including configuration and data.
Returns:
Combined memory configuration and current data.
"""
config = get_memory_config()
memory_data = get_memory_data(user_id=get_effective_user_id())
config = app_config.memory
memory_data = get_memory_data(config, user_id=get_effective_user_id())
return MemoryStatusResponse(
config=MemoryConfigResponse(
+5 -6
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
@@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
@@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name.
Args:
@@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+1 -2
View File
@@ -123,8 +123,7 @@ async def run_messages(
run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
run["thread_id"],
run_id,
run["thread_id"], run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
+69 -45
View File
@@ -4,12 +4,14 @@ import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
@@ -101,9 +103,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
async def list_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -116,11 +118,11 @@ async def list_skills() -> SkillsListResponse:
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
async def install_skill(request: SkillInstallRequest, app_config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return SkillInstallResponse(**result)
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -136,9 +138,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
async def list_custom_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
skills = [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -146,13 +148,13 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
async def get_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config))
except HTTPException:
raise
except Exception as e:
@@ -161,14 +163,18 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
async def update_custom_skill(
skill_name: str,
request: CustomSkillUpdateRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
scan = await scan_skill_content(app_config, request.content, executable=False, location=f"{skill_name}/SKILL.md")
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
skill_file = get_custom_skill_dir(skill_name, app_config) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
@@ -182,9 +188,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
app_config,
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name, app_config)
except HTTPException:
raise
except FileNotFoundError as e:
@@ -197,11 +204,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
async def delete_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
ensure_custom_skill_is_editable(skill_name, app_config)
skill_dir = get_custom_skill_dir(skill_name, app_config)
prev_content = read_custom_skill_content(skill_name, app_config)
try:
append_history(
skill_name,
@@ -214,13 +221,14 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
app_config,
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async()
await refresh_skills_system_prompt_cache_async(app_config)
return {"success": True}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
@@ -232,11 +240,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
async def get_custom_skill_history(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
return CustomSkillHistoryResponse(history=read_history(skill_name, app_config))
except HTTPException:
raise
except Exception as e:
@@ -245,11 +253,15 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
async def rollback_custom_skill(
skill_name: str,
request: SkillRollbackRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
history = read_history(skill_name, app_config)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
@@ -257,8 +269,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
scan = await scan_skill_content(app_config, target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name, app_config)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
@@ -271,12 +283,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
append_history(skill_name, history_entry, app_config)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
append_history(skill_name, history_entry, app_config)
await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name, app_config)
except HTTPException:
raise
except IndexError:
@@ -296,9 +308,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
async def get_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -318,9 +330,14 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
async def update_skill(
skill_name: str,
request: SkillUpdateRequest,
http_request: Request,
app_config: AppConfig = Depends(get_config),
) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@@ -331,22 +348,29 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config()
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled)
# Do not mutate the frozen AppConfig in place. Compose the new skills
# state in a fresh dict, write to disk, and reload AppConfig below so
# every subsequent Depends(get_config) sees the refreshed snapshot.
ext = app_config.extensions
updated_skills = {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()}
updated_skills[skill_name] = {"enabled": request.enabled}
config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()},
"mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": updated_skills,
}
with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
# Reload AppConfig and swap ``app.state.config`` so subsequent
# ``Depends(get_config)`` sees the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
await refresh_skills_system_prompt_cache_async(reloaded)
skills = load_skills(enabled_only=False)
skills = load_skills(reloaded, enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:
+5 -3
View File
@@ -1,11 +1,13 @@
import json
import logging
from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -100,7 +102,7 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request, app_config: AppConfig = Depends(get_config)) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])
@@ -122,7 +124,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=body.model_name, thinking_enabled=False)
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=app_config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
+7 -11
View File
@@ -54,6 +54,7 @@ class RunCreateRequest(BaseModel):
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
class RunResponse(BaseModel):
@@ -311,15 +312,11 @@ async def list_thread_messages(
if i in last_ai_indices:
run_id = msg["run_id"]
fb = feedback_map.get(run_id)
msg["feedback"] = (
{
"feedback_id": fb["feedback_id"],
"rating": fb["rating"],
"comment": fb.get("comment"),
}
if fb
else None
)
msg["feedback"] = {
"feedback_id": fb["feedback_id"],
"rating": fb["rating"],
"comment": fb.get("comment"),
} if fb else None
else:
msg["feedback"] = None
@@ -342,8 +339,7 @@ async def list_run_messages(
"""
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
thread_id,
run_id,
thread_id, run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
+1
View File
@@ -13,6 +13,7 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations
import logging
import re
import time
import uuid
from typing import Any
+11 -10
View File
@@ -4,11 +4,12 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from app.gateway.authz import require_permission
from deerflow.config.app_config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
@@ -60,23 +61,22 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
uploads_cfg = getattr(app_config, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
@@ -85,11 +85,12 @@ def _auto_convert_documents_enabled() -> bool:
@router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=False)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
app_config: AppConfig = Depends(get_config),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
@@ -102,13 +103,13 @@ async def upload_files(
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sandbox_provider = get_sandbox_provider(app_config)
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
auto_convert_documents = _auto_convert_documents_enabled(app_config)
for file in files:
if not file.filename:
+18 -1
View File
@@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import re
@@ -17,7 +18,7 @@ from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.runtime import (
END_SENTINEL,
@@ -211,6 +212,21 @@ async def start_run(
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try:
record = await run_mgr.create_or_reject(
thread_id,
@@ -219,6 +235,7 @@ async def start_run(
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
follow_up_to_run_id=follow_up_to_run_id,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc