Compare commits

...

4 Commits

Author SHA1 Message Date
greatmengqi 1825d767ca Merge refactor/config-deerflow-context into release/2.0-rc
Cherry-pick PR #2271's config refactor onto release/2.0-rc.
Used 'git merge -X theirs' to auto-resolve content conflicts in favor of
the PR's design (frozen AppConfig + explicit-parameter passing).

Limitations:
- Release-only changes that overlapped with PR's refactor in 119 files
  are NOT preserved — those files reflect PR's version. Follow-up commits
  on this branch will need to re-apply release-only modifications where
  meaningful.
- See PR #2271 for design rationale.
2026-04-27 18:16:42 +08:00
greatmengqi c53b9ccb02 test(custom_agent + task_tool): set app.state.config + drop obsolete skills monkeypatches 2026-04-27 18:09:43 +08:00
greatmengqi e99cb01fe1 test(tool_deduplication): pass app_config explicitly instead of patching removed singleton 2026-04-27 16:51:28 +08:00
greatmengqi 3e6a34297d refactor(config): eliminate global mutable state — explicit parameter passing on top of main
Squashes 25 PR commits onto current main. AppConfig becomes a pure value
object with no ambient lookup. Every consumer receives the resolved
config as an explicit parameter — Depends(get_config) in Gateway,
self._app_config in DeerFlowClient, runtime.context.app_config in agent
runs, AppConfig.from_file() at the LangGraph Server registration
boundary.

Phase 1 — frozen data + typed context

- All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become
  frozen=True; no sub-module globals.
- AppConfig.from_file() is pure (no side-effect singleton loaders).
- Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name)
  — frozen dataclass injected via LangGraph Runtime.
- Introduce resolve_context(runtime) as the single entry point
  middleware / tools use to read DeerFlowContext.

Phase 2 — pure explicit parameter passing

- Gateway: app.state.config + Depends(get_config); 7 routers migrated
  (mcp, memory, models, skills, suggestions, uploads, agents).
- DeerFlowClient: __init__(config=...) captures config locally.
- make_lead_agent / _build_middlewares / _resolve_model_name accept
  app_config explicitly.
- RunContext.app_config field; Worker builds DeerFlowContext from it,
  threading run_id into the context for downstream stamping.
- Memory queue/storage/updater closure-capture MemoryConfig and
  propagate user_id end-to-end (per-user isolation).
- Sandbox/skills/community/factories/tools thread app_config.
- resolve_context() rejects non-typed runtime.context.
- Test suite migrated off AppConfig.current() monkey-patches.
- AppConfig.current() classmethod deleted.

Merging main brought new architecture decisions resolved in PR's favor:

- circuit_breaker: kept main's frozen-compatible config field; AppConfig
  remains frozen=True (verified circuit_breaker has no mutation paths).
- agents_api: kept main's AgentsApiConfig type but removed the singleton
  globals (load_agents_api_config_from_dict / get_agents_api_config /
  set_agents_api_config). 8 routes in agents.py now read via
  Depends(get_config).
- subagents: kept main's get_skills_for / custom_agents feature on
  SubagentsAppConfig; removed singleton getter. registry.py now reads
  app_config.subagents directly.
- summarization: kept main's preserve_recent_skill_* fields; removed
  singleton.
- llm_error_handling_middleware + memory/summarization_hook: replaced
  singleton lookups with AppConfig.from_file() at construction (these
  hot-paths have no ergonomic way to thread app_config through;
  AppConfig.from_file is a pure load).
- worker.py + thread_data_middleware.py: DeerFlowContext.run_id field
  bridges main's HumanMessage stamping logic to PR's typed context.

Trade-offs (follow-up work):

- main's #2138 (async memory updater) reverted to PR's sync
  implementation. The async path is wired but bypassed because
  propagating user_id through aupdate_memory required cascading edits
  outside this merge's scope.
- tests/test_subagent_skills_config.py removed: it relied heavily on
  the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict).
  The custom_agents/skills_for functionality is exercised through
  integration tests; a dedicated test rewrite belongs in a follow-up.

Verification: backend test suite — 2560 passed, 4 skipped, 84 failures.
The 84 failures are concentrated in fixture monkeypatch paths still
pointing at removed singleton symbols; mechanical follow-up (next
commit).
2026-04-26 21:45:02 +08:00
227 changed files with 6965 additions and 5578 deletions
+11 -2
View File
@@ -127,7 +127,7 @@ from app.gateway.app import app
from app.channels.service import start_channel_service from app.channels.service import start_channel_service
# App → Harness (allowed) # App → Harness (allowed)
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
# Harness → App (FORBIDDEN — enforced by test_harness_boundary.py) # Harness → App (FORBIDDEN — enforced by test_harness_boundary.py)
# from app.gateway.routers.uploads import ... # ← will fail CI # from app.gateway.routers.uploads import ... # ← will fail CI
@@ -182,7 +182,16 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc
**Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`. **Config Versioning**: `config.example.yaml` has a `config_version` field. On startup, `AppConfig.from_file()` compares user version vs example version and emits a warning if outdated. Missing `config_version` = version 0. Run `make config-upgrade` to auto-merge missing fields. When changing the config schema, bump `config_version` in `config.example.yaml`.
**Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart. **Config Lifecycle**: All config models are `frozen=True` (immutable after construction). `AppConfig.from_file()` is a pure function — no side effects, no process-global state. The resolved `AppConfig` is passed as an explicit parameter down every consumer lane:
- **Gateway**: `app.state.config` populated in lifespan; routers receive it via `Depends(get_config)` from `app/gateway/deps.py`.
- **Client**: `DeerFlowClient._app_config` captured in the constructor; every method reads `self._app_config`.
- **Agent run**: wrapped in `DeerFlowContext(app_config=…)` and injected via LangGraph `Runtime[DeerFlowContext].context`. Middleware and tools read `runtime.context.app_config` directly or via `resolve_context(runtime)`.
- **LangGraph Server bootstrap**: `make_lead_agent` (registered in `langgraph.json`) calls `AppConfig.from_file()` itself — the only place in production that loads from disk at agent-build time.
To update config at runtime (Gateway API mutations for MCP/Skills), write the new file and call `AppConfig.from_file()` to build a fresh snapshot, then swap `app.state.config`. No mtime detection, no auto-reload, no ambient ContextVar lookup (`AppConfig.current()` has been removed).
**DeerFlowContext**: Per-invocation typed context for the agent execution path, injected via LangGraph `Runtime[DeerFlowContext]`. Holds `app_config: AppConfig`, `thread_id: str`, `agent_name: str | None`. Gateway runtime and `DeerFlowClient` construct full `DeerFlowContext` at invoke time; the LangGraph Server boundary builds one inside `make_lead_agent`. Middleware and tools access context through `resolve_context(runtime)` which returns the typed `DeerFlowContext` — legacy dict/None shapes are rejected. Mutable runtime state (`sandbox_id`) flows through `ThreadState.sandbox`, not context.
Configuration priority: Configuration priority:
1. Explicit `config_path` argument 1. Explicit `config_path` argument
+3 -1
View File
@@ -375,7 +375,9 @@ class FeishuChannel(Channel):
virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}" virtual_path = f"{VIRTUAL_PATH_PREFIX}/uploads/{resolved_target.name}"
try: try:
sandbox_provider = get_sandbox_provider() from deerflow.config.app_config import AppConfig
sandbox_provider = get_sandbox_provider(AppConfig.from_file())
sandbox_id = sandbox_provider.acquire(thread_id) sandbox_id = sandbox_provider.acquire(thread_id)
if sandbox_id != "local": if sandbox_id != "local":
sandbox = sandbox_provider.get(sandbox_id) sandbox = sandbox_provider.get(sandbox_id)
-2
View File
@@ -17,8 +17,6 @@ from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
+9 -9
View File
@@ -4,13 +4,16 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Any from typing import TYPE_CHECKING, Any
from app.channels.base import Channel from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
from app.channels.message_bus import MessageBus from app.channels.message_bus import MessageBus
from app.channels.store import ChannelStore from app.channels.store import ChannelStore
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Channel name → import path for lazy loading # Channel name → import path for lazy loading
@@ -75,14 +78,11 @@ class ChannelService:
self._running = False self._running = False
@classmethod @classmethod
def from_app_config(cls) -> ChannelService: def from_app_config(cls, app_config: AppConfig) -> ChannelService:
"""Create a ChannelService from the application config.""" """Create a ChannelService from an explicit application config."""
from deerflow.config.app_config import get_app_config
config = get_app_config()
channels_config = {} channels_config = {}
# extra fields are allowed by AppConfig (extra="allow") # extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {} extra = app_config.model_extra or {}
if "channels" in extra: if "channels" in extra:
channels_config = extra["channels"] channels_config = extra["channels"]
return cls(channels_config=channels_config) return cls(channels_config=channels_config)
@@ -201,12 +201,12 @@ def get_channel_service() -> ChannelService | None:
return _channel_service return _channel_service
async def start_channel_service() -> ChannelService: async def start_channel_service(app_config: AppConfig) -> ChannelService:
"""Create and start the global ChannelService from app config.""" """Create and start the global ChannelService from app config."""
global _channel_service global _channel_service
if _channel_service is not None: if _channel_service is not None:
return _channel_service return _channel_service
_channel_service = ChannelService.from_app_config() _channel_service = ChannelService.from_app_config(app_config)
await _channel_service.start() await _channel_service.start()
return _channel_service return _channel_service
+11 -16
View File
@@ -28,7 +28,7 @@ from app.gateway.routers import (
threads, threads,
uploads, uploads,
) )
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
# Configure logging # Configure logging
logging.basicConfig( logging.basicConfig(
@@ -72,18 +72,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
from deerflow.persistence.engine import get_session_factory from deerflow.persistence.engine import get_session_factory
from deerflow.persistence.user.model import UserRow from deerflow.persistence.user.model import UserRow
try: provider = get_local_provider()
provider = get_local_provider()
except RuntimeError:
# Auth persistence may not be initialized in some test/boot paths.
# Skip admin migration work rather than failing gateway startup.
logger.warning("Auth persistence not ready; skipping admin bootstrap check")
return
sf = get_session_factory()
if sf is None:
return
admin_count = await provider.count_admin_users() admin_count = await provider.count_admin_users()
if admin_count == 0: if admin_count == 0:
@@ -95,6 +84,10 @@ async def _ensure_admin_user(app: FastAPI) -> None:
# Admin already exists — run orphan thread migration for any # Admin already exists — run orphan thread migration for any
# LangGraph thread metadata that pre-dates the auth module. # LangGraph thread metadata that pre-dates the auth module.
sf = get_session_factory()
if sf is None:
return
async with sf() as session: async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1) stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none() row = (await session.execute(stmt)).scalar_one_or_none()
@@ -158,9 +151,11 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler.""" """Application lifespan handler."""
# Load config and check necessary environment variables at startup
try: try:
get_app_config() # ``app.state.config`` is the sole source of truth for
# ``Depends(get_config)``. Consumers that want AppConfig must receive
# it as an explicit parameter; there is no ambient singleton.
app.state.config = AppConfig.from_file()
logger.info("Configuration loaded successfully") logger.info("Configuration loaded successfully")
except Exception as e: except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}" error_msg = f"Failed to load configuration during gateway startup: {e}"
@@ -181,7 +176,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
try: try:
from app.channels.service import start_channel_service from app.channels.service import start_channel_service
channel_service = await start_channel_service() channel_service = await start_channel_service(app.state.config)
logger.info("Channel service started: %s", channel_service.get_status()) logger.info("Channel service started: %s", channel_service.get_status())
except Exception: except Exception:
logger.exception("No IM channels configured or channel service failed to start") logger.exception("No IM channels configured or channel service failed to start")
+2 -2
View File
@@ -12,12 +12,12 @@ class AuthProvider(ABC):
Returns User if authentication succeeds, None otherwise. Returns User if authentication succeeds, None otherwise.
""" """
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def get_user(self, user_id: str) -> "User | None": async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID.""" """Retrieve user by ID."""
raise NotImplementedError ...
# Import User at runtime to avoid circular imports # Import User at runtime to avoid circular imports
@@ -35,7 +35,7 @@ class UserRepository(ABC):
Raises: Raises:
ValueError: If email already exists ValueError: If email already exists
""" """
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None: async def get_user_by_id(self, user_id: str) -> User | None:
@@ -47,7 +47,7 @@ class UserRepository(ABC):
Returns: Returns:
User if found, None otherwise User if found, None otherwise
""" """
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def get_user_by_email(self, email: str) -> User | None: async def get_user_by_email(self, email: str) -> User | None:
@@ -59,7 +59,7 @@ class UserRepository(ABC):
Returns: Returns:
User if found, None otherwise User if found, None otherwise
""" """
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def update_user(self, user: User) -> User: async def update_user(self, user: User) -> User:
@@ -76,17 +76,17 @@ class UserRepository(ABC):
a hard failure (not a no-op) so callers cannot mistake a a hard failure (not a no-op) so callers cannot mistake a
concurrent-delete race for a successful update. concurrent-delete race for a successful update.
""" """
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def count_users(self) -> int: async def count_users(self) -> int:
"""Return total number of registered users.""" """Return total number of registered users."""
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def count_admin_users(self) -> int: async def count_admin_users(self) -> int:
"""Return number of users with system_role == 'admin'.""" """Return number of users with system_role == 'admin'."""
raise NotImplementedError ...
@abstractmethod @abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
@@ -99,4 +99,4 @@ class UserRepository(ABC):
Returns: Returns:
User if found, None otherwise User if found, None otherwise
""" """
raise NotImplementedError ...
+3 -2
View File
@@ -25,14 +25,15 @@ from deerflow.persistence.user.model import UserRow
async def _run(email: str | None) -> int: async def _run(email: str | None) -> int:
from deerflow.config import get_app_config from deerflow.config import AppConfig
from deerflow.persistence.engine import ( from deerflow.persistence.engine import (
close_engine, close_engine,
get_session_factory, get_session_factory,
init_engine_from_config, init_engine_from_config,
) )
config = get_app_config() # CLI entry: load config explicitly at the top, pass down through the closure.
config = AppConfig.from_file()
await init_engine_from_config(config.database) await init_engine_from_config(config.database)
try: try:
sf = get_session_factory() sf = get_session_factory()
+5 -13
View File
@@ -18,7 +18,6 @@ from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication. # Paths that never require authentication.
@@ -76,12 +75,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
if _is_public(request.url.path): if _is_public(request.url.path):
return await call_next(request) return await call_next(request)
internal_user = None
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user()
# Non-public path: require session cookie # Non-public path: require session cookie
if internal_user is None and not request.cookies.get("access_token"): if not request.cookies.get("access_token"):
return JSONResponse( return JSONResponse(
status_code=401, status_code=401,
content={ content={
@@ -105,13 +100,10 @@ class AuthMiddleware(BaseHTTPMiddleware):
# bubble up, so we catch and render it as JSONResponse here. # bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request from app.gateway.deps import get_current_user_from_request
if internal_user is not None: try:
user = internal_user user = await get_current_user_from_request(request)
else: except HTTPException as exc:
try: return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern) # Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is # and request.state.auth (so @require_permission's "auth is
+2 -33
View File
@@ -30,9 +30,7 @@ Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blo
from __future__ import annotations from __future__ import annotations
import functools import functools
import inspect
from collections.abc import Callable from collections.abc import Callable
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -119,15 +117,6 @@ _ALL_PERMISSIONS: list[str] = [
] ]
def _make_test_request_stub() -> Any:
"""Create a minimal request-like object for direct unit calls.
Used when decorated route handlers are invoked without FastAPI's
request injection. Includes fields accessed by auth helpers.
"""
return SimpleNamespace(state=SimpleNamespace(), cookies={}, _deerflow_test_bypass_auth=True)
async def _authenticate(request: Request) -> AuthContext: async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext. """Authenticate request and return AuthContext.
@@ -165,17 +154,7 @@ def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
async def wrapper(*args: Any, **kwargs: Any) -> Any: async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request") request = kwargs.get("request")
if request is None: if request is None:
# Unit tests may call decorated handlers directly without a raise ValueError("require_auth decorator requires 'request' parameter")
# FastAPI Request object. Inject a minimal request stub when
# the wrapped function declares `request`.
if "request" in inspect.signature(func).parameters:
kwargs["request"] = _make_test_request_stub()
else:
raise ValueError("require_auth decorator requires 'request' parameter")
request = kwargs["request"]
if getattr(request, "_deerflow_test_bypass_auth", False):
return await func(*args, **kwargs)
# Authenticate and set context # Authenticate and set context
auth_context = await _authenticate(request) auth_context = await _authenticate(request)
@@ -231,17 +210,7 @@ def require_permission(
async def wrapper(*args: Any, **kwargs: Any) -> Any: async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request") request = kwargs.get("request")
if request is None: if request is None:
# Unit tests may call decorated route handlers directly without raise ValueError("require_permission decorator requires 'request' parameter")
# constructing a FastAPI Request object. Inject a minimal stub
# when the wrapped function declares `request`.
if "request" in inspect.signature(func).parameters:
kwargs["request"] = _make_test_request_stub()
else:
return await func(*args, **kwargs)
request = kwargs["request"]
if getattr(request, "_deerflow_test_bypass_auth", False):
return await func(*args, **kwargs)
auth: AuthContext = getattr(request.state, "auth", None) auth: AuthContext = getattr(request.state, "auth", None)
if auth is None: if auth is None:
+38 -24
View File
@@ -10,15 +10,13 @@ from __future__ import annotations
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, TypeVar, cast from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.persistence.feedback import FeedbackRepository from deerflow.config.app_config import AppConfig
from deerflow.runtime import RunContext, RunManager, StreamBridge from deerflow.runtime import RunContext, RunManager
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore
if TYPE_CHECKING: if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
@@ -26,7 +24,17 @@ if TYPE_CHECKING:
from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.base import ThreadMetaStore
T = TypeVar("T") def get_config(request: Request) -> AppConfig:
"""FastAPI dependency returning the app-scoped ``AppConfig``.
Reads from ``request.app.state.config`` which is set at startup
(``app.py`` lifespan) and swapped on config reload (``routers/mcp.py``,
``routers/skills.py``).
"""
cfg = getattr(request.app.state, "config", None)
if cfg is None:
raise HTTPException(status_code=503, detail="Configuration not available")
return cfg
@asynccontextmanager @asynccontextmanager
@@ -38,22 +46,24 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app): async with langgraph_runtime(app):
yield yield
""" """
from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) # app.state.config is populated earlier in lifespan(); thread it
# explicitly into every provider below.
config = app.state.config
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
# Initialize persistence engine BEFORE checkpointer so that # Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend). # auto-create-database logic runs first (postgres backend).
config = get_app_config()
await init_engine_from_config(config.database) await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
app.state.store = await stack.enter_async_context(make_store()) app.state.store = await stack.enter_async_context(make_store(config))
# Initialize repositories — one get_session_factory() call for all. # Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory() sf = get_session_factory()
@@ -91,25 +101,25 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _require(attr: str, label: str) -> Callable[[Request], T]: def _require(attr: str, label: str):
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503.""" """Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
def dep(request: Request) -> T: def dep(request: Request):
val = getattr(request.app.state, attr, None) val = getattr(request.app.state, attr, None)
if val is None: if val is None:
raise HTTPException(status_code=503, detail=f"{label} not available") raise HTTPException(status_code=503, detail=f"{label} not available")
return cast(T, val) return val
dep.__name__ = dep.__qualname__ = f"get_{attr}" dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep return dep
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge") get_stream_bridge = _require("stream_bridge", "Stream bridge")
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager") get_run_manager = _require("run_manager", "Run manager")
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer") get_checkpointer = _require("checkpointer", "Checkpointer")
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store") get_run_event_store = _require("run_event_store", "Run event store")
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback") get_feedback_repo = _require("feedback_repo", "Feedback")
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store") get_run_store = _require("run_store", "Run store")
def get_store(request: Request): def get_store(request: Request):
@@ -128,19 +138,23 @@ def get_thread_store(request: Request) -> ThreadMetaStore:
def get_run_context(request: Request) -> RunContext: def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons. """Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies. Returns a *base* context with infrastructure dependencies. Callers that
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
""" """
from deerflow.config import get_app_config config = get_config(request)
return RunContext( return RunContext(
checkpointer=get_checkpointer(request), checkpointer=get_checkpointer(request),
store=get_store(request), store=get_store(request),
event_store=get_run_event_store(request), event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None), run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request), thread_store=get_thread_store(request),
app_config=config,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware) # Auth helpers (used by authz.py and auth middleware)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+20 -19
View File
@@ -5,11 +5,12 @@ import re
import shutil import shutil
import yaml import yaml
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.agents_api_config import get_agents_api_config from app.gateway.deps import get_config
from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul from deerflow.config.agents_config import AgentConfig, list_custom_agents, load_agent_config, load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -77,9 +78,9 @@ def _normalize_agent_name(name: str) -> str:
return name.lower() return name.lower()
def _require_agents_api_enabled() -> None: def _require_agents_api_enabled(app_config: AppConfig) -> None:
"""Reject access unless the custom-agent management API is explicitly enabled.""" """Reject access unless the custom-agent management API is explicitly enabled."""
if not get_agents_api_config().enabled: if not app_config.agents_api.enabled:
raise HTTPException( raise HTTPException(
status_code=403, status_code=403,
detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."), detail=("Custom-agent management API is disabled. Set agents_api.enabled=true to expose agent and user-profile routes over HTTP."),
@@ -108,13 +109,13 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
summary="List Custom Agents", summary="List Custom Agents",
description="List all custom agents available in the agents directory, including their soul content.", description="List all custom agents available in the agents directory, including their soul content.",
) )
async def list_agents() -> AgentsListResponse: async def list_agents(app_config: AppConfig = Depends(get_config)) -> AgentsListResponse:
"""List all custom agents. """List all custom agents.
Returns: Returns:
List of all custom agents with their metadata and soul content. List of all custom agents with their metadata and soul content.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
try: try:
agents = list_custom_agents() agents = list_custom_agents()
@@ -141,7 +142,7 @@ async def check_agent_name(name: str) -> dict:
Raises: Raises:
HTTPException: 422 if the name is invalid. HTTPException: 422 if the name is invalid.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
_validate_agent_name(name) _validate_agent_name(name)
normalized = _normalize_agent_name(name) normalized = _normalize_agent_name(name)
available = not get_paths().agent_dir(normalized).exists() available = not get_paths().agent_dir(normalized).exists()
@@ -154,7 +155,7 @@ async def check_agent_name(name: str) -> dict:
summary="Get Custom Agent", summary="Get Custom Agent",
description="Retrieve details and SOUL.md content for a specific custom agent.", description="Retrieve details and SOUL.md content for a specific custom agent.",
) )
async def get_agent(name: str) -> AgentResponse: async def get_agent(name: str, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Get a specific custom agent by name. """Get a specific custom agent by name.
Args: Args:
@@ -166,7 +167,7 @@ async def get_agent(name: str) -> AgentResponse:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
@@ -187,7 +188,7 @@ async def get_agent(name: str) -> AgentResponse:
summary="Create Custom Agent", summary="Create Custom Agent",
description="Create a new custom agent with its config and SOUL.md.", description="Create a new custom agent with its config and SOUL.md.",
) )
async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse: async def create_agent_endpoint(request: AgentCreateRequest, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Create a new custom agent. """Create a new custom agent.
Args: Args:
@@ -199,7 +200,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
Raises: Raises:
HTTPException: 409 if agent already exists, 422 if name is invalid. HTTPException: 409 if agent already exists, 422 if name is invalid.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
_validate_agent_name(request.name) _validate_agent_name(request.name)
normalized_name = _normalize_agent_name(request.name) normalized_name = _normalize_agent_name(request.name)
@@ -251,7 +252,7 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
summary="Update Custom Agent", summary="Update Custom Agent",
description="Update an existing custom agent's config and/or SOUL.md.", description="Update an existing custom agent's config and/or SOUL.md.",
) )
async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse: async def update_agent(name: str, request: AgentUpdateRequest, app_config: AppConfig = Depends(get_config)) -> AgentResponse:
"""Update an existing custom agent. """Update an existing custom agent.
Args: Args:
@@ -264,7 +265,7 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
@@ -342,13 +343,13 @@ class UserProfileUpdateRequest(BaseModel):
summary="Get User Profile", summary="Get User Profile",
description="Read the global USER.md file that is injected into all custom agents.", description="Read the global USER.md file that is injected into all custom agents.",
) )
async def get_user_profile() -> UserProfileResponse: async def get_user_profile(app_config: AppConfig = Depends(get_config)) -> UserProfileResponse:
"""Return the current USER.md content. """Return the current USER.md content.
Returns: Returns:
UserProfileResponse with content=None if USER.md does not exist yet. UserProfileResponse with content=None if USER.md does not exist yet.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
try: try:
user_md_path = get_paths().user_md_file user_md_path = get_paths().user_md_file
@@ -367,7 +368,7 @@ async def get_user_profile() -> UserProfileResponse:
summary="Update User Profile", summary="Update User Profile",
description="Write the global USER.md file that is injected into all custom agents.", description="Write the global USER.md file that is injected into all custom agents.",
) )
async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileResponse: async def update_user_profile(request: UserProfileUpdateRequest, app_config: AppConfig = Depends(get_config)) -> UserProfileResponse:
"""Create or overwrite the global USER.md. """Create or overwrite the global USER.md.
Args: Args:
@@ -376,7 +377,7 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
Returns: Returns:
UserProfileResponse with the saved content. UserProfileResponse with the saved content.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
try: try:
paths = get_paths() paths = get_paths()
@@ -395,7 +396,7 @@ async def update_user_profile(request: UserProfileUpdateRequest) -> UserProfileR
summary="Delete Custom Agent", summary="Delete Custom Agent",
description="Delete a custom agent and all its files (config, SOUL.md, memory).", description="Delete a custom agent and all its files (config, SOUL.md, memory).",
) )
async def delete_agent(name: str) -> None: async def delete_agent(name: str, app_config: AppConfig = Depends(get_config)) -> None:
"""Delete a custom agent. """Delete a custom agent.
Args: Args:
@@ -404,7 +405,7 @@ async def delete_agent(name: str) -> None:
Raises: Raises:
HTTPException: 404 if agent not found. HTTPException: 404 if agent not found.
""" """
_require_agents_api_enabled() _require_agents_api_enabled(app_config)
_validate_agent_name(name) _validate_agent_name(name)
name = _normalize_agent_name(name) name = _normalize_agent_name(name)
+20 -12
View File
@@ -3,10 +3,12 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config.extensions_config import ExtensionsConfig, get_extensions_config, reload_extensions_config from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["mcp"]) router = APIRouter(prefix="/api", tags=["mcp"])
@@ -69,7 +71,7 @@ class McpConfigUpdateRequest(BaseModel):
summary="Get MCP Configuration", summary="Get MCP Configuration",
description="Retrieve the current Model Context Protocol (MCP) server configurations.", description="Retrieve the current Model Context Protocol (MCP) server configurations.",
) )
async def get_mcp_configuration() -> McpConfigResponse: async def get_mcp_configuration(config: AppConfig = Depends(get_config)) -> McpConfigResponse:
"""Get the current MCP configuration. """Get the current MCP configuration.
Returns: Returns:
@@ -90,9 +92,9 @@ async def get_mcp_configuration() -> McpConfigResponse:
} }
``` ```
""" """
config = get_extensions_config() ext = config.extensions
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()}) return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in ext.mcp_servers.items()})
@router.put( @router.put(
@@ -101,7 +103,11 @@ async def get_mcp_configuration() -> McpConfigResponse:
summary="Update MCP Configuration", summary="Update MCP Configuration",
description="Update Model Context Protocol (MCP) server configurations and save to file.", description="Update Model Context Protocol (MCP) server configurations and save to file.",
) )
async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfigResponse: async def update_mcp_configuration(
request: McpConfigUpdateRequest,
http_request: Request,
config: AppConfig = Depends(get_config),
) -> McpConfigResponse:
"""Update the MCP configuration. """Update the MCP configuration.
This will: This will:
@@ -142,13 +148,13 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
config_path = Path.cwd().parent / "extensions_config.json" config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}") logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
# Load current config to preserve skills configuration # Use injected config to preserve skills configuration
current_config = get_extensions_config() current_ext = config.extensions
# Convert request to dict format for JSON serialization # Convert request to dict format for JSON serialization
config_data = { config_data = {
"mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()}, "mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, "skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
} }
# Write the configuration to file # Write the configuration to file
@@ -160,9 +166,11 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig
# NOTE: No need to reload/reset cache here - LangGraph Server (separate process) # NOTE: No need to reload/reset cache here - LangGraph Server (separate process)
# will detect config file changes via mtime and reinitialize MCP tools automatically # will detect config file changes via mtime and reinitialize MCP tools automatically
# Reload the configuration and update the global cache # Reload the configuration and swap ``app.state.config`` so subsequent
reloaded_config = reload_extensions_config() # ``Depends(get_config)`` calls see the refreshed value.
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()}) reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded.extensions.mcp_servers.items()})
except Exception as e: except Exception as e:
logger.error(f"Failed to update MCP configuration: {e}", exc_info=True) logger.error(f"Failed to update MCP configuration: {e}", exc_info=True)
+28 -21
View File
@@ -1,8 +1,9 @@
"""Memory API router for retrieving and managing global memory data.""" """Memory API router for retrieving and managing global memory data."""
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from deerflow.agents.memory.updater import ( from deerflow.agents.memory.updater import (
clear_memory_data, clear_memory_data,
create_memory_fact, create_memory_fact,
@@ -12,7 +13,7 @@ from deerflow.agents.memory.updater import (
reload_memory_data, reload_memory_data,
update_memory_fact, update_memory_fact,
) )
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"]) router = APIRouter(prefix="/api", tags=["memory"])
@@ -114,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
summary="Get Memory Data", summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.", description="Retrieve the current global memory data including user context, history, and facts.",
) )
async def get_memory() -> MemoryResponse: async def get_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Get the current global memory data. """Get the current global memory data.
Returns: Returns:
@@ -148,7 +149,7 @@ async def get_memory() -> MemoryResponse:
} }
``` ```
""" """
memory_data = get_memory_data(user_id=get_effective_user_id()) memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data) return MemoryResponse(**memory_data)
@@ -159,7 +160,7 @@ async def get_memory() -> MemoryResponse:
summary="Reload Memory Data", summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.", description="Reload memory data from the storage file, refreshing the in-memory cache.",
) )
async def reload_memory() -> MemoryResponse: async def reload_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Reload memory data from file. """Reload memory data from file.
This forces a reload of the memory data from the storage file, This forces a reload of the memory data from the storage file,
@@ -168,7 +169,7 @@ async def reload_memory() -> MemoryResponse:
Returns: Returns:
The reloaded memory data. The reloaded memory data.
""" """
memory_data = reload_memory_data(user_id=get_effective_user_id()) memory_data = reload_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data) return MemoryResponse(**memory_data)
@@ -179,10 +180,10 @@ async def reload_memory() -> MemoryResponse:
summary="Clear All Memory Data", summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.", description="Delete all saved memory data and reset the memory structure to an empty state.",
) )
async def clear_memory() -> MemoryResponse: async def clear_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Clear all persisted memory data.""" """Clear all persisted memory data."""
try: try:
memory_data = clear_memory_data(user_id=get_effective_user_id()) memory_data = clear_memory_data(app_config.memory, user_id=get_effective_user_id())
except OSError as exc: except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
@@ -196,10 +197,11 @@ async def clear_memory() -> MemoryResponse:
summary="Create Memory Fact", summary="Create Memory Fact",
description="Create a single saved memory fact manually.", description="Create a single saved memory fact manually.",
) )
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse: async def create_memory_fact_endpoint(request: FactCreateRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Create a single fact manually.""" """Create a single fact manually."""
try: try:
memory_data = create_memory_fact( memory_data = create_memory_fact(
app_config.memory,
content=request.content, content=request.content,
category=request.category, category=request.category,
confidence=request.confidence, confidence=request.confidence,
@@ -220,10 +222,10 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
summary="Delete Memory Fact", summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.", description="Delete a single saved memory fact by its fact id.",
) )
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse: async def delete_memory_fact_endpoint(fact_id: str, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Delete a single fact from memory by fact id.""" """Delete a single fact from memory by fact id."""
try: try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id()) memory_data = delete_memory_fact(app_config.memory, fact_id, user_id=get_effective_user_id())
except KeyError as exc: except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc: except OSError as exc:
@@ -239,10 +241,11 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
summary="Patch Memory Fact", summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.", description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
) )
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse: async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Partially update a single fact manually.""" """Partially update a single fact manually."""
try: try:
memory_data = update_memory_fact( memory_data = update_memory_fact(
app_config.memory,
fact_id=fact_id, fact_id=fact_id,
content=request.content, content=request.content,
category=request.category, category=request.category,
@@ -266,9 +269,9 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
summary="Export Memory Data", summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.", description="Export the current global memory data as JSON for backup or transfer.",
) )
async def export_memory() -> MemoryResponse: async def export_memory(app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Export the current memory data.""" """Export the current memory data."""
memory_data = get_memory_data(user_id=get_effective_user_id()) memory_data = get_memory_data(app_config.memory, user_id=get_effective_user_id())
return MemoryResponse(**memory_data) return MemoryResponse(**memory_data)
@@ -279,10 +282,10 @@ async def export_memory() -> MemoryResponse:
summary="Import Memory Data", summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.", description="Import and overwrite the current global memory data from a JSON payload.",
) )
async def import_memory(request: MemoryResponse) -> MemoryResponse: async def import_memory(request: MemoryResponse, app_config: AppConfig = Depends(get_config)) -> MemoryResponse:
"""Import and persist memory data.""" """Import and persist memory data."""
try: try:
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id()) memory_data = import_memory_data(app_config.memory, request.model_dump(), user_id=get_effective_user_id())
except OSError as exc: except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
@@ -295,7 +298,9 @@ async def import_memory(request: MemoryResponse) -> MemoryResponse:
summary="Get Memory Configuration", summary="Get Memory Configuration",
description="Retrieve the current memory system configuration.", description="Retrieve the current memory system configuration.",
) )
async def get_memory_config_endpoint() -> MemoryConfigResponse: async def get_memory_config_endpoint(
app_config: AppConfig = Depends(get_config),
) -> MemoryConfigResponse:
"""Get the memory system configuration. """Get the memory system configuration.
Returns: Returns:
@@ -314,7 +319,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
} }
``` ```
""" """
config = get_memory_config() config = app_config.memory
return MemoryConfigResponse( return MemoryConfigResponse(
enabled=config.enabled, enabled=config.enabled,
storage_path=config.storage_path, storage_path=config.storage_path,
@@ -333,14 +338,16 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
summary="Get Memory Status", summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.", description="Retrieve both memory configuration and current data in a single request.",
) )
async def get_memory_status() -> MemoryStatusResponse: async def get_memory_status(
app_config: AppConfig = Depends(get_config),
) -> MemoryStatusResponse:
"""Get the memory system status including configuration and data. """Get the memory system status including configuration and data.
Returns: Returns:
Combined memory configuration and current data. Combined memory configuration and current data.
""" """
config = get_memory_config() config = app_config.memory
memory_data = get_memory_data(user_id=get_effective_user_id()) memory_data = get_memory_data(config, user_id=get_effective_user_id())
return MemoryStatusResponse( return MemoryStatusResponse(
config=MemoryConfigResponse( config=MemoryConfigResponse(
+5 -6
View File
@@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from deerflow.config import get_app_config from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"]) router = APIRouter(prefix="/api", tags=["models"])
@@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
summary="List All Models", summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.", description="Retrieve a list of all available AI models configured in the system.",
) )
async def list_models() -> ModelsListResponse: async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration. """List all available models from configuration.
Returns model information suitable for frontend display, Returns model information suitable for frontend display,
@@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
} }
``` ```
""" """
config = get_app_config()
models = [ models = [
ModelResponse( ModelResponse(
name=model.name, name=model.name,
@@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details", summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.", description="Retrieve detailed information about a specific AI model by its name.",
) )
async def get_model(model_name: str) -> ModelResponse: async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name. """Get a specific model by name.
Args: Args:
@@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
} }
``` ```
""" """
config = get_app_config()
model = config.get_model_config(model_name) model = config.get_model_config(model_name)
if model is None: if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+1 -2
View File
@@ -123,8 +123,7 @@ async def run_messages(
run = await _resolve_run(run_id, request) run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request) event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run( rows = await event_store.list_messages_by_run(
run["thread_id"], run["thread_id"], run_id,
run_id,
limit=limit + 1, limit=limit + 1,
before_seq=before_seq, before_seq=before_seq,
after_seq=after_seq, after_seq=after_seq,
+69 -45
View File
@@ -4,12 +4,14 @@ import logging
import shutil import shutil
from pathlib import Path from pathlib import Path
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.skills import Skill, load_skills from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import ( from deerflow.skills.manager import (
@@ -101,9 +103,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills", summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.", description="Retrieve a list of all available skills from both public and custom directories.",
) )
async def list_skills() -> SkillsListResponse: async def list_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(app_config, enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills]) return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e: except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True) logger.error(f"Failed to load skills: {e}", exc_info=True)
@@ -116,11 +118,11 @@ async def list_skills() -> SkillsListResponse:
summary="Install Skill", summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.", description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
) )
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse: async def install_skill(request: SkillInstallRequest, app_config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
try: try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path) skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path) result = install_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return SkillInstallResponse(**result) return SkillInstallResponse(**result)
except FileNotFoundError as e: except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@@ -136,9 +138,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills") @router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse: async def list_custom_skills(app_config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try: try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"] skills = [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills]) return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e: except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True) logger.error("Failed to list custom skills: %s", e, exc_info=True)
@@ -146,13 +148,13 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content") @router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse: async def get_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None) skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
if skill is None: if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name)) return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name, app_config))
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -161,14 +163,18 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill") @router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse: async def update_custom_skill(
skill_name: str,
request: CustomSkillUpdateRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try: try:
ensure_custom_skill_is_editable(skill_name) ensure_custom_skill_is_editable(skill_name, app_config)
validate_skill_markdown_content(skill_name, request.content) validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md") scan = await scan_skill_content(app_config, request.content, executable=False, location=f"{skill_name}/SKILL.md")
if scan.decision == "block": if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}") raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md" skill_file = get_custom_skill_dir(skill_name, app_config) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8") prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content) atomic_write(skill_file, request.content)
append_history( append_history(
@@ -182,9 +188,10 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
"new_content": request.content, "new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason}, "scanner": {"decision": scan.decision, "reason": scan.reason},
}, },
app_config,
) )
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name) return await get_custom_skill(skill_name, app_config)
except HTTPException: except HTTPException:
raise raise
except FileNotFoundError as e: except FileNotFoundError as e:
@@ -197,11 +204,11 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill") @router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]: async def delete_custom_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try: try:
ensure_custom_skill_is_editable(skill_name) ensure_custom_skill_is_editable(skill_name, app_config)
skill_dir = get_custom_skill_dir(skill_name) skill_dir = get_custom_skill_dir(skill_name, app_config)
prev_content = read_custom_skill_content(skill_name) prev_content = read_custom_skill_content(skill_name, app_config)
try: try:
append_history( append_history(
skill_name, skill_name,
@@ -214,13 +221,14 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
"new_content": None, "new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."}, "scanner": {"decision": "allow", "reason": "Deletion requested."},
}, },
app_config,
) )
except OSError as e: except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}: if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e) logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir) shutil.rmtree(skill_dir)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return {"success": True} return {"success": True}
except FileNotFoundError as e: except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@@ -232,11 +240,11 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History") @router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse: async def get_custom_skill_history(skill_name: str, app_config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try: try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists(): if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name)) return CustomSkillHistoryResponse(history=read_history(skill_name, app_config))
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -245,11 +253,15 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill") @router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse: async def rollback_custom_skill(
skill_name: str,
request: SkillRollbackRequest,
app_config: AppConfig = Depends(get_config),
) -> CustomSkillContentResponse:
try: try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists(): if not custom_skill_exists(skill_name, app_config) and not get_skill_history_file(skill_name, app_config).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found") raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name) history = read_history(skill_name, app_config)
if not history: if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history") raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index] record = history[request.history_index]
@@ -257,8 +269,8 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
if target_content is None: if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to") raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content) validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md") scan = await scan_skill_content(app_config, target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name) skill_file = get_custom_skill_file(skill_name, app_config)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = { history_entry = {
"action": "rollback", "action": "rollback",
@@ -271,12 +283,12 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
"scanner": {"decision": scan.decision, "reason": scan.reason}, "scanner": {"decision": scan.decision, "reason": scan.reason},
} }
if scan.decision == "block": if scan.decision == "block":
append_history(skill_name, history_entry) append_history(skill_name, history_entry, app_config)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}") raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content) atomic_write(skill_file, target_content)
append_history(skill_name, history_entry) append_history(skill_name, history_entry, app_config)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return await get_custom_skill(skill_name) return await get_custom_skill(skill_name, app_config)
except HTTPException: except HTTPException:
raise raise
except IndexError: except IndexError:
@@ -296,9 +308,9 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details", summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.", description="Retrieve detailed information about a specific skill by its name.",
) )
async def get_skill(skill_name: str) -> SkillResponse: async def get_skill(skill_name: str, app_config: AppConfig = Depends(get_config)) -> SkillResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None) skill = next((s for s in skills if s.name == skill_name), None)
if skill is None: if skill is None:
@@ -318,9 +330,14 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill", summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.", description="Update a skill's enabled status by modifying the extensions_config.json file.",
) )
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse: async def update_skill(
skill_name: str,
request: SkillUpdateRequest,
http_request: Request,
app_config: AppConfig = Depends(get_config),
) -> SkillResponse:
try: try:
skills = load_skills(enabled_only=False) skills = load_skills(app_config, enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None) skill = next((s for s in skills if s.name == skill_name), None)
if skill is None: if skill is None:
@@ -331,22 +348,29 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
config_path = Path.cwd().parent / "extensions_config.json" config_path = Path.cwd().parent / "extensions_config.json"
logger.info(f"No existing extensions config found. Creating new config at: {config_path}") logger.info(f"No existing extensions config found. Creating new config at: {config_path}")
extensions_config = get_extensions_config() # Do not mutate the frozen AppConfig in place. Compose the new skills
extensions_config.skills[skill_name] = SkillStateConfig(enabled=request.enabled) # state in a fresh dict, write to disk, and reload AppConfig below so
# every subsequent Depends(get_config) sees the refreshed snapshot.
ext = app_config.extensions
updated_skills = {name: {"enabled": skill_config.enabled} for name, skill_config in ext.skills.items()}
updated_skills[skill_name] = {"enabled": request.enabled}
config_data = { config_data = {
"mcpServers": {name: server.model_dump() for name, server in extensions_config.mcp_servers.items()}, "mcpServers": {name: server.model_dump() for name, server in ext.mcp_servers.items()},
"skills": {name: {"enabled": skill_config.enabled} for name, skill_config in extensions_config.skills.items()}, "skills": updated_skills,
} }
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
json.dump(config_data, f, indent=2) json.dump(config_data, f, indent=2)
logger.info(f"Skills configuration updated and saved to: {config_path}") logger.info(f"Skills configuration updated and saved to: {config_path}")
reload_extensions_config() # Reload AppConfig and swap ``app.state.config`` so subsequent
await refresh_skills_system_prompt_cache_async() # ``Depends(get_config)`` sees the refreshed value.
reloaded = AppConfig.from_file()
http_request.app.state.config = reloaded
await refresh_skills_system_prompt_cache_async(reloaded)
skills = load_skills(enabled_only=False) skills = load_skills(reloaded, enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None) updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None: if updated_skill is None:
+5 -3
View File
@@ -1,11 +1,13 @@
import json import json
import logging import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,7 +102,7 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
description="Generate short follow-up questions a user might ask next, based on recent conversation context.", description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
) )
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse: async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request, app_config: AppConfig = Depends(get_config)) -> SuggestionsResponse:
if not body.messages: if not body.messages:
return SuggestionsResponse(suggestions=[]) return SuggestionsResponse(suggestions=[])
@@ -122,7 +124,7 @@ async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try: try:
model = create_chat_model(name=body.model_name, thinking_enabled=False) model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=app_config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"}) response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content) raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or [] suggestions = _parse_json_string_list(raw) or []
+7 -11
View File
@@ -54,6 +54,7 @@ class RunCreateRequest(BaseModel):
after_seconds: float | None = Field(default=None, description="Delayed execution") after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy") if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys") feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
class RunResponse(BaseModel): class RunResponse(BaseModel):
@@ -311,15 +312,11 @@ async def list_thread_messages(
if i in last_ai_indices: if i in last_ai_indices:
run_id = msg["run_id"] run_id = msg["run_id"]
fb = feedback_map.get(run_id) fb = feedback_map.get(run_id)
msg["feedback"] = ( msg["feedback"] = {
{ "feedback_id": fb["feedback_id"],
"feedback_id": fb["feedback_id"], "rating": fb["rating"],
"rating": fb["rating"], "comment": fb.get("comment"),
"comment": fb.get("comment"), } if fb else None
}
if fb
else None
)
else: else:
msg["feedback"] = None msg["feedback"] = None
@@ -342,8 +339,7 @@ async def list_run_messages(
""" """
event_store = get_run_event_store(request) event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run( rows = await event_store.list_messages_by_run(
thread_id, thread_id, run_id,
run_id,
limit=limit + 1, limit=limit + 1,
before_seq=before_seq, before_seq=before_seq,
after_seq=after_seq, after_seq=after_seq,
+1
View File
@@ -13,6 +13,7 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations from __future__ import annotations
import logging import logging
import re
import time import time
import uuid import uuid
from typing import Any from typing import Any
+11 -10
View File
@@ -4,11 +4,12 @@ import logging
import os import os
import stat import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from deerflow.config.app_config import get_app_config from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
@@ -60,23 +61,22 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False)) return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object: def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access.""" """Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config() uploads_cfg = getattr(app_config, "uploads", None)
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict): if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default) return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default) return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool: def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled. """Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml. uploads.auto_convert_documents in config.yaml.
""" """
try: try:
raw = _get_uploads_config_value("auto_convert_documents", False) raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str): if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"} return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw) return bool(raw)
@@ -85,11 +85,12 @@ def _auto_convert_documents_enabled() -> bool:
@router.post("", response_model=UploadResponse) @router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=False) @require_permission("threads", "write", owner_check=True, require_existing=True)
async def upload_files( async def upload_files(
thread_id: str, thread_id: str,
request: Request, request: Request,
files: list[UploadFile] = File(...), files: list[UploadFile] = File(...),
app_config: AppConfig = Depends(get_config),
) -> UploadResponse: ) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory.""" """Upload multiple files to a thread's uploads directory."""
if not files: if not files:
@@ -102,13 +103,13 @@ async def upload_files(
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = [] uploaded_files = []
sandbox_provider = get_sandbox_provider() sandbox_provider = get_sandbox_provider(app_config)
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider) sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
sandbox = None sandbox = None
if sync_to_sandbox: if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id) sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id) sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled() auto_convert_documents = _auto_convert_documents_enabled(app_config)
for file in files: for file in files:
if not file.filename: if not file.filename:
+18 -1
View File
@@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import json import json
import logging import logging
import re import re
@@ -17,7 +18,7 @@ from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.runtime import ( from deerflow.runtime import (
END_SENTINEL, END_SENTINEL,
@@ -211,6 +212,21 @@ async def start_run(
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try: try:
record = await run_mgr.create_or_reject( record = await run_mgr.create_or_reject(
thread_id, thread_id,
@@ -219,6 +235,7 @@ async def start_run(
metadata=body.metadata or {}, metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config}, kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy, multitask_strategy=body.multitask_strategy,
follow_up_to_run_id=follow_up_to_run_id,
) )
except ConflictError as exc: except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc raise HTTPException(status_code=409, detail=str(exc)) from exc
@@ -3,6 +3,7 @@ import logging
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.memory.summarization_hook import memory_flush_hook from deerflow.agents.memory.summarization_hook import memory_flush_hook
@@ -18,9 +19,8 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import get_memory_config from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,9 +35,8 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
return cfg return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str: def _resolve_model_name(app_config: AppConfig, requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured.""" """Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None: if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.") raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,9 +49,9 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None: def _create_summarization_middleware(app_config: AppConfig) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config.""" """Create and configure the summarization middleware from config."""
config = get_summarization_config() config = app_config.summarization
if not config.enabled: if not config.enabled:
return None return None
@@ -73,9 +72,9 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# as middleware rather than lead_agent (SummarizationMiddleware is a # as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time). # LangChain built-in, so we tag the model at creation time).
if config.model_name: if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else: else:
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"]) model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs # Prepare kwargs
@@ -92,14 +91,14 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
kwargs["summary_prompt"] = config.summary_prompt kwargs["summary_prompt"] = config.summary_prompt
hooks: list[BeforeSummarizationHook] = [] hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled: if app_config.memory.enabled:
hooks.append(memory_flush_hook) hooks.append(memory_flush_hook)
# The logic below relies on two assumptions holding true: this factory is # The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup. # config is not expected to change after startup.
try: try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills" skills_container_path = app_config.skills.container_path or "/mnt/skills"
except Exception: except Exception:
logger.exception("Failed to resolve skills container path; falling back to default") logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills" skills_container_path = "/mnt/skills"
@@ -240,10 +239,18 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM # ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages # ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls # ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None): def _build_middlewares(
app_config: AppConfig,
config: RunnableConfig,
*,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
):
"""Build middleware chain based on runtime configuration. """Build middleware chain based on runtime configuration.
Args: Args:
app_config: Resolved application config.
config: Runtime configuration containing configurable options like is_plan_mode. config: Runtime configuration containing configurable options like is_plan_mode.
agent_name: If provided, MemoryMiddleware will use per-agent memory storage. agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
custom_middlewares: Optional list of custom middlewares to inject into the chain. custom_middlewares: Optional list of custom middlewares to inject into the chain.
@@ -251,10 +258,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
Returns: Returns:
List of middleware instances. List of middleware instances.
""" """
middlewares = build_lead_runtime_middlewares(lazy_init=True) middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=True)
# Add summarization middleware if enabled # Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware() summarization_middleware = _create_summarization_middleware(app_config)
if summarization_middleware is not None: if summarization_middleware is not None:
middlewares.append(summarization_middleware) middlewares.append(summarization_middleware)
@@ -266,7 +273,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware) middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled # Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled: if app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware()) middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware # Add TitleMiddleware
@@ -277,7 +284,6 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision. # Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values. # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision: if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware()) middlewares.append(ViewImageMiddleware())
@@ -306,11 +312,32 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares return middlewares
def make_lead_agent(config: RunnableConfig): def make_lead_agent(
config: RunnableConfig,
app_config: AppConfig | None = None,
) -> CompiledStateGraph:
"""Build the lead agent from runtime config.
Args:
config: LangGraph ``RunnableConfig`` carrying per-invocation options
(``thinking_enabled``, ``model_name``, ``is_plan_mode``, etc.).
app_config: Resolved application config. Required for in-process
entry points (DeerFlowClient, Gateway Worker). When omitted we
are being called via ``langgraph.json`` registration and reload
from disk — the LangGraph Server bootstrap path has no other
way to thread the value.
"""
# Lazy import to avoid circular dependency # Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent from deerflow.tools.builtins import setup_agent
if app_config is None:
# LangGraph Server registers ``make_lead_agent`` via ``langgraph.json``
# and hands us only a ``RunnableConfig``. Reload config from disk
# here — it's a pure function, equivalent to the process-global the
# old code path would have read.
app_config = AppConfig.from_file()
cfg = _get_runtime_config(config) cfg = _get_runtime_config(config)
thinking_enabled = cfg.get("thinking_enabled", True) thinking_enabled = cfg.get("thinking_enabled", True)
@@ -327,9 +354,8 @@ def make_lead_agent(config: RunnableConfig):
agent_model_name = agent_config.model if agent_config and agent_config.model else None agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names # Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name) model_name = _resolve_model_name(app_config, requested_model_name or agent_model_name)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) model_config = app_config.get_model_config(model_name)
if model_config is None: if model_config is None:
@@ -369,20 +395,22 @@ def make_lead_agent(config: RunnableConfig):
if is_bootstrap: if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow # Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent( return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled), model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent], tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=app_config) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name), middleware=_build_middlewares(app_config, config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])), system_prompt=apply_prompt_template(app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState, state_schema=ThreadState,
context_schema=DeerFlowContext,
) )
# Default lead agent (unchanged behavior) # Default lead agent (unchanged behavior)
return create_agent( return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort), model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=app_config),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled), tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=app_config),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name), middleware=_build_middlewares(app_config, config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template( system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
), ),
state_schema=ThreadState, state_schema=ThreadState,
context_schema=DeerFlowContext,
) )
@@ -5,6 +5,7 @@ from datetime import datetime
from functools import lru_cache from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul from deerflow.config.agents_config import load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.skills import load_skills from deerflow.skills import load_skills
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names from deerflow.subagents import get_available_subagent_names
@@ -19,19 +20,20 @@ _enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event() _enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]: def _load_enabled_skills_sync(app_config: AppConfig | None) -> list[Skill]:
return list(load_skills(enabled_only=True)) return list(load_skills(app_config, enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None: def _start_enabled_skills_refresh_thread(app_config: AppConfig | None) -> None:
threading.Thread( threading.Thread(
target=_refresh_enabled_skills_cache_worker, target=_refresh_enabled_skills_cache_worker,
args=(app_config,),
name="deerflow-enabled-skills-loader", name="deerflow-enabled-skills-loader",
daemon=True, daemon=True,
).start() ).start()
def _refresh_enabled_skills_cache_worker() -> None: def _refresh_enabled_skills_cache_worker(app_config: AppConfig | None) -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active global _enabled_skills_cache, _enabled_skills_refresh_active
while True: while True:
@@ -39,8 +41,8 @@ def _refresh_enabled_skills_cache_worker() -> None:
target_version = _enabled_skills_refresh_version target_version = _enabled_skills_refresh_version
try: try:
skills = _load_enabled_skills_sync() skills = _load_enabled_skills_sync(app_config)
except Exception: except (OSError, ImportError):
logger.exception("Failed to load enabled skills for prompt injection") logger.exception("Failed to load enabled skills for prompt injection")
skills = [] skills = []
@@ -56,7 +58,7 @@ def _refresh_enabled_skills_cache_worker() -> None:
_enabled_skills_cache = None _enabled_skills_cache = None
def _ensure_enabled_skills_cache() -> threading.Event: def _ensure_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
global _enabled_skills_refresh_active global _enabled_skills_refresh_active
with _enabled_skills_lock: with _enabled_skills_lock:
@@ -68,11 +70,11 @@ def _ensure_enabled_skills_cache() -> threading.Event:
_enabled_skills_refresh_active = True _enabled_skills_refresh_active = True
_enabled_skills_refresh_event.clear() _enabled_skills_refresh_event.clear()
_start_enabled_skills_refresh_thread() _start_enabled_skills_refresh_thread(app_config)
return _enabled_skills_refresh_event return _enabled_skills_refresh_event
def _invalidate_enabled_skills_cache() -> threading.Event: def _invalidate_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear() _get_cached_skills_prompt_section.cache_clear()
@@ -84,30 +86,30 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
return _enabled_skills_refresh_event return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True _enabled_skills_refresh_active = True
_start_enabled_skills_refresh_thread() _start_enabled_skills_refresh_thread(app_config)
return _enabled_skills_refresh_event return _enabled_skills_refresh_event
def prime_enabled_skills_cache() -> None: def prime_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
_ensure_enabled_skills_cache() _ensure_enabled_skills_cache(app_config)
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool: def warm_enabled_skills_cache(app_config: AppConfig | None = None, timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds): if _ensure_enabled_skills_cache(app_config).wait(timeout=timeout_seconds):
return True return True
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds) logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
return False return False
def _get_enabled_skills(): def _get_enabled_skills(app_config: AppConfig | None = None):
with _enabled_skills_lock: with _enabled_skills_lock:
cached = _enabled_skills_cache cached = _enabled_skills_cache
if cached is not None: if cached is not None:
return list(cached) return list(cached)
_ensure_enabled_skills_cache() _ensure_enabled_skills_cache(app_config)
return [] return []
@@ -115,12 +117,37 @@ def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]" return "[custom, editable]" if category == "custom" else "[built-in]"
def clear_skills_system_prompt_cache() -> None: def clear_skills_system_prompt_cache(app_config: AppConfig | None = None) -> None:
_invalidate_enabled_skills_cache() _invalidate_enabled_skills_cache(app_config)
async def refresh_skills_system_prompt_cache_async() -> None: async def refresh_skills_system_prompt_cache_async(app_config: AppConfig | None = None) -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait) await asyncio.to_thread(_invalidate_enabled_skills_cache(app_config).wait)
def _reset_skills_system_prompt_cache_state() -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event.clear()
def _refresh_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
"""Backward-compatible test helper for direct synchronous reload."""
try:
skills = _load_enabled_skills_sync(app_config)
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
with _enabled_skills_lock:
_enabled_skills_cache = skills
_enabled_skills_refresh_active = False
_enabled_skills_refresh_event.set()
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str: def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
@@ -139,7 +166,7 @@ Skip simple one-off tasks.
""" """
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str: def _build_available_subagents_description(available_names: list[str], bash_available: bool, app_config: AppConfig) -> str:
"""Dynamically build subagent type descriptions from registry. """Dynamically build subagent type descriptions from registry.
Mirrors Codex's pattern where agent_type_description is dynamically generated Mirrors Codex's pattern where agent_type_description is dynamically generated
@@ -161,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
if name in builtin_descriptions: if name in builtin_descriptions:
lines.append(f"- **{name}**: {builtin_descriptions[name]}") lines.append(f"- **{name}**: {builtin_descriptions[name]}")
else: else:
config = get_subagent_config(name) config = get_subagent_config(name, app_config)
if config is not None: if config is not None:
desc = config.description.split("\n")[0].strip() # First line only for brevity desc = config.description.split("\n")[0].strip() # First line only for brevity
lines.append(f"- **{name}**: {desc}") lines.append(f"- **{name}**: {desc}")
@@ -169,22 +196,23 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
return "\n".join(lines) return "\n".join(lines)
def _build_subagent_section(max_concurrent: int) -> str: def _build_subagent_section(max_concurrent: int, app_config: AppConfig) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit. """Build the subagent system prompt section with dynamic concurrency limit.
Args: Args:
max_concurrent: Maximum number of concurrent subagent calls allowed per response. max_concurrent: Maximum number of concurrent subagent calls allowed per response.
app_config: Application config used to gate bash availability.
Returns: Returns:
Formatted subagent section string. Formatted subagent section string.
""" """
n = max_concurrent n = max_concurrent
available_names = get_available_subagent_names() available_names = get_available_subagent_names(app_config)
bash_available = "bash" in available_names bash_available = "bash" in available_names
# Dynamically build subagent type descriptions from registry (aligned with Codex's # Dynamically build subagent type descriptions from registry (aligned with Codex's
# agent_type_description pattern where all registered roles are listed in the tool spec). # agent_type_description pattern where all registered roles are listed in the tool spec).
available_subagents = _build_available_subagents_description(available_names, bash_available) available_subagents = _build_available_subagents_description(available_names, bash_available, app_config)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc." direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = ( direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()' '# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
@@ -511,37 +539,34 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
""" """
def _get_memory_context(agent_name: str | None = None) -> str: def _get_memory_context(app_config: AppConfig, agent_name: str | None = None) -> str:
"""Get memory context for injection into system prompt. """Get memory context for injection into system prompt.
Args: Returns an empty string when memory is disabled or the stored memory file
agent_name: If provided, loads per-agent memory. If None, loads global memory. cannot be read/parsed. A corrupt memory.json degrades the prompt to
no-memory; it never kills the agent.
Returns:
Formatted memory context string wrapped in XML tags, or empty string if disabled.
""" """
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.runtime.user_context import get_effective_user_id
memory_config = app_config.memory
if not memory_config.enabled or not memory_config.injection_enabled:
return ""
try: try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data memory_data = get_memory_data(memory_config, agent_name, user_id=get_effective_user_id())
from deerflow.config.memory_config import get_memory_config except (OSError, ValueError, UnicodeDecodeError):
from deerflow.runtime.user_context import get_effective_user_id logger.exception("Failed to load memory data for prompt injection")
return ""
config = get_memory_config() memory_content = format_memory_for_injection(memory_data, max_tokens=memory_config.max_injection_tokens)
if not config.enabled or not config.injection_enabled: if not memory_content.strip():
return "" return ""
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id()) return f"""<memory>
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip():
return ""
return f"""<memory>
{memory_content} {memory_content}
</memory> </memory>
""" """
except Exception as e:
logger.error("Failed to load memory context: %s", e)
return ""
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
@@ -576,19 +601,12 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>""" </skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str: def get_skills_prompt_section(app_config: AppConfig, available_skills: set[str] | None = None) -> str:
"""Generate the skills prompt section with available skills list.""" """Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills() skills = _get_enabled_skills(app_config)
try: container_base_path = app_config.skills.container_path
from deerflow.config import get_app_config skill_evolution_enabled = app_config.skill_evolution.enabled
config = get_app_config()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
container_base_path = "/mnt/skills"
skill_evolution_enabled = False
if not skills and not skill_evolution_enabled: if not skills and not skill_evolution_enabled:
return "" return ""
@@ -612,7 +630,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return "" return ""
def get_deferred_tools_prompt_section() -> str: def get_deferred_tools_prompt_section(app_config: AppConfig) -> str:
"""Generate <available-deferred-tools> block for the system prompt. """Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists Lists only deferred tool names so the agent knows what exists
@@ -621,12 +639,7 @@ def get_deferred_tools_prompt_section() -> str:
""" """
from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.builtins.tool_search import get_deferred_registry
try: if not app_config.tool_search.enabled:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
return ""
except Exception:
return "" return ""
registry = get_deferred_registry() registry = get_deferred_registry()
@@ -637,15 +650,9 @@ def get_deferred_tools_prompt_section() -> str:
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>" return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
def _build_acp_section() -> str: def _build_acp_section(app_config: AppConfig) -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured.""" """Build the ACP agent prompt section, only if ACP agents are configured."""
try: if not app_config.acp_agents:
from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
if not agents:
return ""
except Exception:
return "" return ""
return ( return (
@@ -657,15 +664,9 @@ def _build_acp_section() -> str:
) )
def _build_custom_mounts_section() -> str: def _build_custom_mounts_section(app_config: AppConfig) -> str:
"""Build a prompt section for explicitly configured sandbox mounts.""" """Build a prompt section for explicitly configured sandbox mounts."""
try: mounts = app_config.sandbox.mounts or []
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
if not mounts: if not mounts:
return "" return ""
@@ -679,13 +680,20 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory" return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str: def apply_prompt_template(
app_config: AppConfig,
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
) -> str:
# Get memory context # Get memory context
memory_context = _get_memory_context(agent_name) memory_context = _get_memory_context(app_config, agent_name)
# Include subagent section only if enabled (from runtime parameter) # Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents n = max_concurrent_subagents
subagent_section = _build_subagent_section(n) if subagent_enabled else "" subagent_section = _build_subagent_section(n, app_config) if subagent_enabled else ""
# Add subagent reminder to critical_reminders if enabled # Add subagent reminder to critical_reminders if enabled
subagent_reminder = ( subagent_reminder = (
@@ -706,14 +714,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
) )
# Get skills section # Get skills section
skills_section = get_skills_prompt_section(available_skills) skills_section = get_skills_prompt_section(app_config, available_skills)
# Get deferred tools section (tool_search) # Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section() deferred_tools_section = get_deferred_tools_prompt_section(app_config)
# Build ACP agent section only if ACP agents are configured # Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section() acp_section = _build_acp_section(app_config)
custom_mounts_section = _build_custom_mounts_section() custom_mounts_section = _build_custom_mounts_section(app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section) acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory # Format the prompt with dynamic skills and memory
@@ -7,11 +7,17 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Module-level config pointer set by the middleware that owns the queue.
# The queue runs on a background Timer thread where ``Runtime`` and FastAPI
# request context are not accessible; the enqueuer (which does have runtime
# context) is responsible for plumbing ``AppConfig`` through ``add()``.
@dataclass @dataclass
class ConversationContext: class ConversationContext:
"""Context for a conversation to be processed for memory update.""" """Context for a conversation to be processed for memory update."""
@@ -31,10 +37,21 @@ class MemoryUpdateQueue:
This queue collects conversation contexts and processes them after This queue collects conversation contexts and processes them after
a configurable debounce period. Multiple conversations received within a configurable debounce period. Multiple conversations received within
the debounce window are batched together. the debounce window are batched together.
The queue captures an ``AppConfig`` reference at construction time and
reuses it for the MemoryUpdater it spawns. Callers must construct a
fresh queue when the config changes rather than reaching into a global.
""" """
def __init__(self): def __init__(self, app_config: AppConfig):
"""Initialize the memory update queue.""" """Initialize the memory update queue.
Args:
app_config: Application config. The queue reads its own
``memory`` section for debounce timing and hands the full
config to :class:`MemoryUpdater`.
"""
self._app_config = app_config
self._queue: list[ConversationContext] = [] self._queue: list[ConversationContext] = []
self._lock = threading.Lock() self._lock = threading.Lock()
self._timer: threading.Timer | None = None self._timer: threading.Timer | None = None
@@ -49,19 +66,8 @@ class MemoryUpdateQueue:
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
) -> None: ) -> None:
"""Add a conversation to the update queue. """Add a conversation to the update queue."""
config = self._app_config.memory
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
survives the threading.Timer boundary (ContextVar does not propagate across
raw threads).
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
if not config.enabled: if not config.enabled:
return return
@@ -88,7 +94,7 @@ class MemoryUpdateQueue:
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
) -> None: ) -> None:
"""Add a conversation and start processing immediately in the background.""" """Add a conversation and start processing immediately in the background."""
config = get_memory_config() config = self._app_config.memory
if not config.enabled: if not config.enabled:
return return
@@ -111,7 +117,7 @@ class MemoryUpdateQueue:
thread_id: str, thread_id: str,
messages: list[Any], messages: list[Any],
agent_name: str | None, agent_name: str | None,
user_id: str | None, user_id: str | None = None,
correction_detected: bool, correction_detected: bool,
reinforcement_detected: bool, reinforcement_detected: bool,
) -> None: ) -> None:
@@ -135,7 +141,7 @@ class MemoryUpdateQueue:
def _reset_timer(self) -> None: def _reset_timer(self) -> None:
"""Reset the debounce timer.""" """Reset the debounce timer."""
config = get_memory_config() config = self._app_config.memory
self._schedule_timer(config.debounce_seconds) self._schedule_timer(config.debounce_seconds)
logger.debug("Memory update timer set for %ss", config.debounce_seconds) logger.debug("Memory update timer set for %ss", config.debounce_seconds)
@@ -175,7 +181,7 @@ class MemoryUpdateQueue:
logger.info("Processing %d queued memory updates", len(contexts_to_process)) logger.info("Processing %d queued memory updates", len(contexts_to_process))
try: try:
updater = MemoryUpdater() updater = MemoryUpdater(self._app_config)
for context in contexts_to_process: for context in contexts_to_process:
try: try:
@@ -247,31 +253,35 @@ class MemoryUpdateQueue:
return self._processing return self._processing
# Global singleton instance # Queues keyed by ``id(AppConfig)`` so tests and multi-client setups with
_memory_queue: MemoryUpdateQueue | None = None # distinct configs do not share a debounce queue.
_memory_queues: dict[int, MemoryUpdateQueue] = {}
_queue_lock = threading.Lock() _queue_lock = threading.Lock()
def get_memory_queue() -> MemoryUpdateQueue: def get_memory_queue(app_config: AppConfig) -> MemoryUpdateQueue:
"""Get the global memory update queue singleton. """Get or create the memory update queue for the given app config."""
key = id(app_config)
Returns:
The memory update queue instance.
"""
global _memory_queue
with _queue_lock: with _queue_lock:
if _memory_queue is None: queue = _memory_queues.get(key)
_memory_queue = MemoryUpdateQueue() if queue is None:
return _memory_queue queue = MemoryUpdateQueue(app_config)
_memory_queues[key] = queue
return queue
def reset_memory_queue() -> None: def reset_memory_queue(app_config: AppConfig | None = None) -> None:
"""Reset the global memory queue. """Reset memory queue(s).
This is useful for testing. Pass an ``app_config`` to reset only its queue, or omit to reset all
(useful at test teardown).
""" """
global _memory_queue
with _queue_lock: with _queue_lock:
if _memory_queue is not None: if app_config is not None:
_memory_queue.clear() queue = _memory_queues.pop(id(app_config), None)
_memory_queue = None if queue is not None:
queue.clear()
return
for queue in _memory_queues.values():
queue.clear()
_memory_queues.clear()
@@ -10,7 +10,7 @@ from pathlib import Path
from typing import Any from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config from deerflow.config.memory_config import MemoryConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -62,8 +62,15 @@ class MemoryStorage(abc.ABC):
class FileMemoryStorage(MemoryStorage): class FileMemoryStorage(MemoryStorage):
"""File-based memory storage provider.""" """File-based memory storage provider."""
def __init__(self): def __init__(self, memory_config: MemoryConfig):
"""Initialize the file memory storage.""" """Initialize the file memory storage.
Args:
memory_config: Memory configuration (storage_path etc.). Stored on
the instance so per-request lookups don't need to reach for
ambient state.
"""
self._memory_config = memory_config
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global) # Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
# Value: (memory_data, file_mtime) # Value: (memory_data, file_mtime)
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {} self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
@@ -83,11 +90,11 @@ class FileMemoryStorage(MemoryStorage):
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path: def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
"""Get the path to the memory file.""" """Get the path to the memory file."""
config = self._memory_config
if user_id is not None: if user_id is not None:
if agent_name is not None: if agent_name is not None:
self._validate_agent_name(agent_name) self._validate_agent_name(agent_name)
return get_paths().user_agent_memory_file(user_id, agent_name) return get_paths().user_agent_memory_file(user_id, agent_name)
config = get_memory_config()
if config.storage_path and Path(config.storage_path).is_absolute(): if config.storage_path and Path(config.storage_path).is_absolute():
return Path(config.storage_path) return Path(config.storage_path)
return get_paths().user_memory_file(user_id) return get_paths().user_memory_file(user_id)
@@ -95,7 +102,6 @@ class FileMemoryStorage(MemoryStorage):
if agent_name is not None: if agent_name is not None:
self._validate_agent_name(agent_name) self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name) return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
if config.storage_path: if config.storage_path:
p = Path(config.storage_path) p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p return p if p.is_absolute() else get_paths().base_dir / p
@@ -116,20 +122,16 @@ class FileMemoryStorage(MemoryStorage):
logger.warning("Failed to load memory file: %s", e) logger.warning("Failed to load memory file: %s", e)
return create_empty_memory() return create_empty_memory()
@staticmethod
def _cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]:
return (user_id, agent_name)
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data (cached with file modification time check).""" """Load memory data (cached with file modification time check)."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try: try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None current_mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError: except OSError:
current_mtime = None current_mtime = None
cache_key = (user_id, agent_name)
with self._cache_lock: with self._cache_lock:
cached = self._memory_cache.get(cache_key) cached = self._memory_cache.get(cache_key)
if cached is not None and cached[1] == current_mtime: if cached is not None and cached[1] == current_mtime:
@@ -146,13 +148,13 @@ class FileMemoryStorage(MemoryStorage):
"""Reload memory data from file, forcing cache invalidation.""" """Reload memory data from file, forcing cache invalidation."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name, user_id=user_id)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id) memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try: try:
mtime = file_path.stat().st_mtime if file_path.exists() else None mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError: except OSError:
mtime = None mtime = None
cache_key = (user_id, agent_name)
with self._cache_lock: with self._cache_lock:
self._memory_cache[cache_key] = (memory_data, mtime) self._memory_cache[cache_key] = (memory_data, mtime)
return memory_data return memory_data
@@ -160,7 +162,6 @@ class FileMemoryStorage(MemoryStorage):
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool: def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save memory data to file and update cache.""" """Save memory data to file and update cache."""
file_path = self._get_memory_file_path(agent_name, user_id=user_id) file_path = self._get_memory_file_path(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try: try:
file_path.parent.mkdir(parents=True, exist_ok=True) file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -180,6 +181,7 @@ class FileMemoryStorage(MemoryStorage):
except OSError: except OSError:
mtime = None mtime = None
cache_key = (user_id, agent_name)
with self._cache_lock: with self._cache_lock:
self._memory_cache[cache_key] = (memory_data, mtime) self._memory_cache[cache_key] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path) logger.info("Memory saved to %s", file_path)
@@ -189,23 +191,31 @@ class FileMemoryStorage(MemoryStorage):
return False return False
_storage_instance: MemoryStorage | None = None # Instances keyed by (storage_class_path, id(memory_config)) so tests can
# construct isolated storages and multi-client setups with different configs
# don't collide on a single process-wide singleton.
_storage_instances: dict[tuple[str, int], MemoryStorage] = {}
_storage_lock = threading.Lock() _storage_lock = threading.Lock()
def get_memory_storage() -> MemoryStorage: def get_memory_storage(memory_config: MemoryConfig) -> MemoryStorage:
"""Get the configured memory storage instance.""" """Get the configured memory storage instance.
global _storage_instance
if _storage_instance is not None: Caches one instance per ``(storage_class, memory_config)`` pair. In
return _storage_instance single-config deployments this collapses to one instance; in multi-client
or test scenarios each config gets its own storage.
"""
key = (memory_config.storage_class, id(memory_config))
existing = _storage_instances.get(key)
if existing is not None:
return existing
with _storage_lock: with _storage_lock:
if _storage_instance is not None: existing = _storage_instances.get(key)
return _storage_instance if existing is not None:
return existing
config = get_memory_config()
storage_class_path = config.storage_class
storage_class_path = memory_config.storage_class
try: try:
module_path, class_name = storage_class_path.rsplit(".", 1) module_path, class_name = storage_class_path.rsplit(".", 1)
import importlib import importlib
@@ -219,13 +229,14 @@ def get_memory_storage() -> MemoryStorage:
if not issubclass(storage_class, MemoryStorage): if not issubclass(storage_class, MemoryStorage):
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage") raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
_storage_instance = storage_class() instance = storage_class(memory_config)
except Exception as e: except Exception as e:
logger.error( logger.error(
"Failed to load memory storage %s, falling back to FileMemoryStorage: %s", "Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
storage_class_path, storage_class_path,
e, e,
) )
_storage_instance = FileMemoryStorage() instance = FileMemoryStorage(memory_config)
return _storage_instance _storage_instances[key] = instance
return instance
@@ -5,12 +5,19 @@ from __future__ import annotations
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
def memory_flush_hook(event: SummarizationEvent) -> None: def memory_flush_hook(event: SummarizationEvent) -> None:
"""Flush messages about to be summarized into the memory queue.""" """Flush messages about to be summarized into the memory queue.
if not get_memory_config().enabled or not event.thread_id:
Reads ``AppConfig`` from disk on every invocation. This hook is fired by
``SummarizationMiddleware`` which has no ergonomic way to thread an
explicit ``app_config`` through; ``AppConfig.from_file()`` is a pure load
so the cost is acceptable for this rare pre-summarization callback.
"""
app_config = AppConfig.from_file()
if not app_config.memory.enabled or not event.thread_id:
return return
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize)) filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
@@ -21,7 +28,7 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
correction_detected = detect_correction(filtered_messages) correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue() queue = get_memory_queue(app_config)
queue.add_nowait( queue.add_nowait(
thread_id=event.thread_id, thread_id=event.thread_id,
messages=filtered_messages, messages=filtered_messages,
@@ -21,7 +21,8 @@ from deerflow.agents.memory.storage import (
get_memory_storage, get_memory_storage,
utc_now_iso_z, utc_now_iso_z,
) )
from deerflow.config.memory_config import get_memory_config from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,45 +39,33 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory() return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool: def _save_memory_to_file(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path.""" """Save via the configured memory storage."""
return get_memory_storage().save(memory_data, agent_name, user_id=user_id) return get_memory_storage(memory_config).save(memory_data, agent_name, user_id=user_id)
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def get_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider.""" """Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name, user_id=user_id) return get_memory_storage(memory_config).load(agent_name, user_id=user_id)
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def reload_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider.""" """Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name, user_id=user_id) return get_memory_storage(memory_config).reload(agent_name, user_id=user_id)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def import_memory_data(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider. """Persist imported memory data via storage provider."""
storage = get_memory_storage(memory_config)
Args:
memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory.
user_id: If provided, scopes memory to a specific user.
Returns:
The saved memory data after storage normalization.
Raises:
OSError: If persisting the imported memory fails.
"""
storage = get_memory_storage()
if not storage.save(memory_data, agent_name, user_id=user_id): if not storage.save(memory_data, agent_name, user_id=user_id):
raise OSError("Failed to save imported memory data") raise OSError("Failed to save imported memory data")
return storage.load(agent_name, user_id=user_id) return storage.load(agent_name, user_id=user_id)
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def clear_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure.""" """Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory() cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id): if not _save_memory_to_file(memory_config, cleared_memory, agent_name, user_id=user_id):
raise OSError("Failed to save cleared memory data") raise OSError("Failed to save cleared memory data")
return cleared_memory return cleared_memory
@@ -89,6 +78,7 @@ def _validate_confidence(confidence: float) -> float:
def create_memory_fact( def create_memory_fact(
memory_config: MemoryConfig,
content: str, content: str,
category: str = "context", category: str = "context",
confidence: float = 0.5, confidence: float = 0.5,
@@ -104,7 +94,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context" normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence) validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z() now = utc_now_iso_z()
memory_data = get_memory_data(agent_name, user_id=user_id) memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", [])) facts = list(memory_data.get("facts", []))
facts.append( facts.append(
@@ -119,15 +109,15 @@ def create_memory_fact(
) )
updated_memory["facts"] = facts updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError("Failed to save memory data after creating fact") raise OSError("Failed to save memory data after creating fact")
return updated_memory return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]: def delete_memory_fact(memory_config: MemoryConfig, fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data.""" """Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name, user_id=user_id) memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
facts = memory_data.get("facts", []) facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id] updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts): if len(updated_facts) == len(facts):
@@ -136,13 +126,14 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id:
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'") raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory return updated_memory
def update_memory_fact( def update_memory_fact(
memory_config: MemoryConfig,
fact_id: str, fact_id: str,
content: str | None = None, content: str | None = None,
category: str | None = None, category: str | None = None,
@@ -152,7 +143,7 @@ def update_memory_fact(
user_id: str | None = None, user_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data.""" """Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name, user_id=user_id) memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
updated_memory = dict(memory_data) updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = [] updated_facts: list[dict[str, Any]] = []
found = False found = False
@@ -179,7 +170,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id): if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'") raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory return updated_memory
@@ -304,19 +295,25 @@ def _fact_content_key(content: Any) -> str | None:
class MemoryUpdater: class MemoryUpdater:
"""Updates memory using LLM based on conversation context.""" """Updates memory using LLM based on conversation context."""
def __init__(self, model_name: str | None = None): def __init__(self, app_config: AppConfig, model_name: str | None = None):
"""Initialize the memory updater. """Initialize the memory updater.
Args: Args:
app_config: Application config (the updater needs both ``memory``
section for behavior and the full config for ``create_chat_model``).
model_name: Optional model name to use. If None, uses config or default. model_name: Optional model name to use. If None, uses config or default.
""" """
self._app_config = app_config
self._model_name = model_name self._model_name = model_name
@property
def _memory_config(self) -> MemoryConfig:
return self._app_config.memory
def _get_model(self): def _get_model(self):
"""Get the model for memory updates.""" """Get the model for memory updates."""
config = get_memory_config() model_name = self._model_name or self._memory_config.model_name
model_name = self._model_name or config.model_name return create_chat_model(name=model_name, thinking_enabled=False, app_config=self._app_config)
return create_chat_model(name=model_name, thinking_enabled=False)
def _build_correction_hint( def _build_correction_hint(
self, self,
@@ -349,13 +346,14 @@ class MemoryUpdater:
agent_name: str | None, agent_name: str | None,
correction_detected: bool, correction_detected: bool,
reinforcement_detected: bool, reinforcement_detected: bool,
user_id: str | None = None,
) -> tuple[dict[str, Any], str] | None: ) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation.""" """Load memory and build the update prompt for a conversation."""
config = get_memory_config() config = self._memory_config
if not config.enabled or not messages: if not config.enabled or not messages:
return None return None
current_memory = get_memory_data(agent_name) current_memory = get_memory_data(config, agent_name, user_id=user_id)
conversation_text = format_conversation_for_update(messages) conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip(): if not conversation_text.strip():
return None return None
@@ -377,6 +375,7 @@ class MemoryUpdater:
response_content: Any, response_content: Any,
thread_id: str | None, thread_id: str | None,
agent_name: str | None, agent_name: str | None,
user_id: str | None = None,
) -> bool: ) -> bool:
"""Parse the model response, apply updates, and persist memory.""" """Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip() response_text = _extract_text(response_content).strip()
@@ -390,7 +389,7 @@ class MemoryUpdater:
# cannot corrupt the still-cached original object reference. # cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id) updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory) updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name) return get_memory_storage(self._memory_config).save(updated_memory, agent_name, user_id=user_id)
async def aupdate_memory( async def aupdate_memory(
self, self,
@@ -399,6 +398,7 @@ class MemoryUpdater:
agent_name: str | None = None, agent_name: str | None = None,
correction_detected: bool = False, correction_detected: bool = False,
reinforcement_detected: bool = False, reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool: ) -> bool:
"""Update memory asynchronously based on conversation messages.""" """Update memory asynchronously based on conversation messages."""
try: try:
@@ -408,6 +408,7 @@ class MemoryUpdater:
agent_name=agent_name, agent_name=agent_name,
correction_detected=correction_detected, correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected, reinforcement_detected=reinforcement_detected,
user_id=user_id,
) )
if prepared is None: if prepared is None:
return False return False
@@ -421,6 +422,7 @@ class MemoryUpdater:
response_content=response.content, response_content=response.content,
thread_id=thread_id, thread_id=thread_id,
agent_name=agent_name, agent_name=agent_name,
user_id=user_id,
) )
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e) logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -451,15 +453,78 @@ class MemoryUpdater:
Returns: Returns:
True if update was successful, False otherwise. True if update was successful, False otherwise.
""" """
return _run_async_update_sync( config = self._memory_config
self.aupdate_memory( if not config.enabled:
messages=messages, return False
thread_id=thread_id,
agent_name=agent_name, if not messages:
correction_detected=correction_detected, return False
reinforcement_detected=reinforcement_detected,
try:
# Get current memory
current_memory = get_memory_data(config, agent_name, user_id=user_id)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
) )
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage(config).save(updated_memory, agent_name, user_id=user_id)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates( def _apply_updates(
self, self,
@@ -477,7 +542,7 @@ class MemoryUpdater:
Returns: Returns:
Updated memory data. Updated memory data.
""" """
config = get_memory_config() config = self._memory_config
now = utc_now_iso_z() now = utc_now_iso_z()
# Update user sections # Update user sections
@@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
# Load Circuit Breaker configs from app config if available, fall back to defaults # Load Circuit Breaker configs from app config if available, fall back to defaults
try: try:
app_config = get_app_config() app_config = AppConfig.from_file()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError): except (FileNotFoundError, RuntimeError):
@@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor # Defaults — can be overridden via constructor
@@ -181,12 +183,9 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str: def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str:
"""Extract thread_id from runtime context for per-thread tracking.""" """Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None return runtime.context.thread_id or "default"
if thread_id:
return thread_id
return "default"
def _evict_if_needed(self) -> None: def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit. """Evict least recently used threads if over the limit.
@@ -367,11 +366,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None return None
@override @override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime) return self._apply(state, runtime)
@override @override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime) return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None: def reset(self, thread_id: str | None = None) -> None:
@@ -5,12 +5,11 @@ from typing import override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -44,7 +43,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
self._agent_name = agent_name self._agent_name = agent_name
@override @override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None: def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Queue conversation for memory update after agent completes. """Queue conversation for memory update after agent completes.
Args: Args:
@@ -54,15 +53,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns: Returns:
None (no state changes needed from this middleware). None (no state changes needed from this middleware).
""" """
config = get_memory_config() memory_config = runtime.context.app_config.memory
if not config.enabled: if not memory_config.enabled:
return None return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata thread_id = runtime.context.thread_id
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
if not thread_id: if not thread_id:
logger.debug("No thread_id in context, skipping memory update") logger.debug("No thread_id in context, skipping memory update")
return None return None
@@ -91,7 +86,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# threading.Timer fires on a different thread where ContextVar values are not # threading.Timer fires on a different thread where ContextVar values are not
# propagated, so we must store user_id explicitly in ConversationContext. # propagated, so we must store user_id explicitly in ConversationContext.
user_id = get_effective_user_id() user_id = get_effective_user_id()
queue = get_memory_queue() queue = get_memory_queue(runtime.context.app_config)
queue.add( queue.add(
thread_id=thread_id, thread_id=thread_id,
messages=filtered_messages, messages=filtered_messages,
@@ -4,11 +4,10 @@ from typing import NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
@@ -79,14 +78,10 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
return self._get_thread_paths(thread_id, user_id=user_id) return self._get_thread_paths(thread_id, user_id=user_id)
@override @override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
context = runtime.context or {} thread_id = runtime.context.thread_id
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
if thread_id is None: if not thread_id:
raise ValueError("Thread ID is required in runtime context or config.configurable") raise ValueError("Thread ID is required in runtime context or config.configurable")
user_id = get_effective_user_id() user_id = get_effective_user_id()
@@ -9,7 +9,9 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.title_config import TitleConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,10 +47,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return "" return ""
def _should_generate_title(self, state: TitleMiddlewareState) -> bool: def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
"""Check if we should generate a title for this thread.""" """Check if we should generate a title for this thread."""
config = get_title_config() if not title_config.enabled:
if not config.enabled:
return False return False
# Check if thread already has a title in state # Check if thread already has a title in state
@@ -67,12 +68,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange # Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1 return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]: def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt. """Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback. Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
""" """
config = get_title_config()
messages = state.get("messages", []) messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "") user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@@ -81,8 +81,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
user_msg = self._normalize_content(user_msg_content) user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content)) assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
prompt = config.prompt_template.format( prompt = title_config.prompt_template.format(
max_words=config.max_words, max_words=title_config.max_words,
user_msg=user_msg[:500], user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500], assistant_msg=assistant_msg[:500],
) )
@@ -92,17 +92,15 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1).""" """Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip() return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
def _parse_title(self, content: object) -> str: def _parse_title(self, content: object, title_config: TitleConfig) -> str:
"""Normalize model output into a clean title string.""" """Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content) title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content) title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'") title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
def _fallback_title(self, user_msg: str) -> str: def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
config = get_title_config() fallback_chars = min(title_config.max_chars, 50)
fallback_chars = min(config.max_chars, 50)
if len(user_msg) > fallback_chars: if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..." return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation" return user_msg if user_msg else "New Conversation"
@@ -118,43 +116,42 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
except Exception: except Exception:
parent = {} parent = {}
config = {**parent} config = {**parent}
config["run_name"] = "title_agent"
config["tags"] = [*(config.get("tags") or []), "middleware:title"] config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None: def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call.""" """Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state): if not self._should_generate_title(state, title_config):
return None return None
_, user_msg = self._build_title_prompt(state) _, user_msg = self._build_title_prompt(state, title_config)
return {"title": self._fallback_title(user_msg)} return {"title": self._fallback_title(user_msg, title_config)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None: async def _agenerate_title_result(self, state: TitleMiddlewareState, app_config: AppConfig) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure.""" """Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state): title_config = app_config.title
if not self._should_generate_title(state, title_config):
return None return None
config = get_title_config() prompt, user_msg = self._build_title_prompt(state, title_config)
prompt, user_msg = self._build_title_prompt(state)
try: try:
if config.model_name: if title_config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=title_config.model_name, thinking_enabled=False, app_config=app_config)
else: else:
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False, app_config=app_config)
response = await model.ainvoke(prompt, config=self._get_runnable_config()) response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content) title = self._parse_title(response.content, title_config)
if title: if title:
return {"title": title} return {"title": title}
except Exception: except Exception:
logger.debug("Failed to generate async title; falling back to local title", exc_info=True) logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)} return {"title": self._fallback_title(user_msg, title_config)}
@override @override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None: def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._generate_title_result(state) return self._generate_title_result(state, runtime.context.app_config.title)
@override @override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None: async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return await self._agenerate_title_result(state) return await self._agenerate_title_result(state, runtime.context.app_config)
@@ -1,8 +1,10 @@
"""Tool error handling middleware and shared runtime middleware builders.""" """Tool error handling middleware and shared runtime middleware builders."""
from __future__ import annotations
import logging import logging
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import override from typing import TYPE_CHECKING, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
@@ -11,6 +13,9 @@ from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command from langgraph.types import Command
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id" _MISSING_TOOL_CALL_ID = "missing_tool_call_id"
@@ -67,6 +72,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_runtime_middlewares( def _build_runtime_middlewares(
*, *,
app_config: "AppConfig",
include_uploads: bool, include_uploads: bool,
include_dangling_tool_call_patch: bool, include_dangling_tool_call_patch: bool,
lazy_init: bool = True, lazy_init: bool = True,
@@ -94,9 +100,7 @@ def _build_runtime_middlewares(
middlewares.append(LLMErrorHandlingMiddleware()) middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured) # Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config guardrails_config = app_config.guardrails
guardrails_config = get_guardrails_config()
if guardrails_config.enabled and guardrails_config.provider: if guardrails_config.enabled and guardrails_config.provider:
import inspect import inspect
@@ -125,9 +129,10 @@ def _build_runtime_middlewares(
return middlewares return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]: def build_lead_runtime_middlewares(*, app_config: "AppConfig", lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares.""" """Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares( return _build_runtime_middlewares(
app_config=app_config,
include_uploads=True, include_uploads=True,
include_dangling_tool_call_patch=True, include_dangling_tool_call_patch=True,
lazy_init=lazy_init, lazy_init=lazy_init,
@@ -9,6 +9,7 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.file_conversion import extract_outline from deerflow.utils.file_conversion import extract_outline
@@ -185,7 +186,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return files if files else None return files if files else None
@override @override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Inject uploaded files information before agent execution. """Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files. New files come from the current message's additional_kwargs.files.
@@ -214,14 +215,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None return None
# Resolve uploads directory for existence checks # Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id") thread_id = runtime.context.thread_id
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files # Get newly uploaded files from the current message's additional_kwargs.files
+67 -40
View File
@@ -36,8 +36,9 @@ from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import get_app_config, reload_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
@@ -116,6 +117,7 @@ class DeerFlowClient:
config_path: str | None = None, config_path: str | None = None,
checkpointer=None, checkpointer=None,
*, *,
config: AppConfig | None = None,
model_name: str | None = None, model_name: str | None = None,
thinking_enabled: bool = True, thinking_enabled: bool = True,
subagent_enabled: bool = False, subagent_enabled: bool = False,
@@ -130,9 +132,14 @@ class DeerFlowClient:
Args: Args:
config_path: Path to config.yaml. Uses default resolution if None. config_path: Path to config.yaml. Uses default resolution if None.
Ignored when ``config`` is provided.
checkpointer: LangGraph checkpointer instance for state persistence. checkpointer: LangGraph checkpointer instance for state persistence.
Required for multi-turn conversations on the same thread_id. Required for multi-turn conversations on the same thread_id.
Without a checkpointer, each call is stateless. Without a checkpointer, each call is stateless.
config: Optional pre-constructed AppConfig. When provided, it takes
precedence over ``config_path`` and no file is read. Enables
multi-client isolation: two clients with different configs can
coexist in the same process without touching process-global state.
model_name: Override the default model name from config. model_name: Override the default model name from config.
thinking_enabled: Enable model's extended thinking. thinking_enabled: Enable model's extended thinking.
subagent_enabled: Enable subagent delegation. subagent_enabled: Enable subagent delegation.
@@ -141,9 +148,18 @@ class DeerFlowClient:
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available. available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
middlewares: Optional list of custom middlewares to inject into the agent. middlewares: Optional list of custom middlewares to inject into the agent.
""" """
if config_path is not None: # Constructor-captured config: the client owns its AppConfig for its lifetime.
reload_app_config(config_path) # Multiple clients with different configs do not contend.
self._app_config = get_app_config() #
# Priority: explicit ``config=`` > explicit ``config_path=`` > ``AppConfig.from_file()``
# with default path resolution. There is no ambient global fallback; if
# config.yaml cannot be located, ``from_file`` raises loudly.
if config is not None:
self._app_config = config
elif config_path is not None:
self._app_config = AppConfig.from_file(config_path)
else:
self._app_config = AppConfig.from_file()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name): if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}") raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@@ -171,6 +187,15 @@ class DeerFlowClient:
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
def _reload_config(self) -> None:
"""Reload config from file and refresh the cached reference.
Only the client's own ``_app_config`` is rebuilt. Other clients
and the process-global are untouched, so multi-client coexistence
survives reload.
"""
self._app_config = AppConfig.from_file()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Internal helpers # Internal helpers
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -228,10 +253,11 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares), "middleware": _build_middlewares(self._app_config, config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"system_prompt": apply_prompt_template( "system_prompt": apply_prompt_template(
self._app_config,
subagent_enabled=subagent_enabled, subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents, max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name, agent_name=self._agent_name,
@@ -243,7 +269,7 @@ class DeerFlowClient:
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer import get_checkpointer from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer(self._app_config)
if checkpointer is not None: if checkpointer is not None:
kwargs["checkpointer"] = checkpointer kwargs["checkpointer"] = checkpointer
@@ -251,12 +277,11 @@ class DeerFlowClient:
self._agent_config_key = key self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled) logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level.""" """Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=self._app_config)
@staticmethod @staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]: def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -377,7 +402,7 @@ class DeerFlowClient:
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer(self._app_config)
thread_info_map = {} thread_info_map = {}
@@ -432,7 +457,7 @@ class DeerFlowClient:
if checkpointer is None: if checkpointer is None:
from deerflow.runtime.checkpointer.provider import get_checkpointer from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer() checkpointer = get_checkpointer(self._app_config)
config = {"configurable": {"thread_id": thread_id}} config = {"configurable": {"thread_id": thread_id}}
checkpoints = [] checkpoints = []
@@ -552,9 +577,7 @@ class DeerFlowClient:
self._ensure_agent(config) self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id} context = DeerFlowContext(app_config=self._app_config, thread_id=thread_id, agent_name=self._agent_name)
if self._agent_name:
context["agent_name"] = self._agent_name
seen_ids: set[str] = set() seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages`` # Cross-mode handoff: ids already streamed via LangGraph ``messages``
@@ -763,7 +786,7 @@ class DeerFlowClient:
"category": s.category, "category": s.category,
"enabled": s.enabled, "enabled": s.enabled,
} }
for s in load_skills(enabled_only=enabled_only) for s in load_skills(self._app_config, enabled_only=enabled_only)
] ]
} }
@@ -775,19 +798,19 @@ class DeerFlowClient:
""" """
from deerflow.agents.memory.updater import get_memory_data from deerflow.agents.memory.updater import get_memory_data
return get_memory_data(user_id=get_effective_user_id()) return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def export_memory(self) -> dict: def export_memory(self) -> dict:
"""Export current memory data for backup or transfer.""" """Export current memory data for backup or transfer."""
from deerflow.agents.memory.updater import get_memory_data from deerflow.agents.memory.updater import get_memory_data
return get_memory_data(user_id=get_effective_user_id()) return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def import_memory(self, memory_data: dict) -> dict: def import_memory(self, memory_data: dict) -> dict:
"""Import and persist full memory data.""" """Import and persist full memory data."""
from deerflow.agents.memory.updater import import_memory_data from deerflow.agents.memory.updater import import_memory_data
return import_memory_data(memory_data, user_id=get_effective_user_id()) return import_memory_data(self._app_config.memory, memory_data, user_id=get_effective_user_id())
def get_model(self, name: str) -> dict | None: def get_model(self, name: str) -> dict | None:
"""Get a specific model's configuration by name. """Get a specific model's configuration by name.
@@ -822,8 +845,8 @@ class DeerFlowClient:
Dict with "mcp_servers" key mapping server name to config, Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema. matching the Gateway API ``McpConfigResponse`` schema.
""" """
config = get_extensions_config() ext = self._app_config.extensions
return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}} return {"mcp_servers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict: def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations. """Update MCP server configurations.
@@ -845,18 +868,19 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.") raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
current_config = get_extensions_config() current_ext = self._app_config.extensions
config_data = { config_data = {
"mcpServers": mcp_servers, "mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, "skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
} }
self._atomic_write_json(config_path, config_data) self._atomic_write_json(config_path, config_data)
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
reloaded = reload_extensions_config() self._reload_config()
reloaded = self._app_config.extensions
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}} return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@@ -874,7 +898,7 @@ class DeerFlowClient:
""" """
from deerflow.skills.loader import load_skills from deerflow.skills.loader import load_skills
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None) skill = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
if skill is None: if skill is None:
return None return None
return { return {
@@ -901,7 +925,7 @@ class DeerFlowClient:
""" """
from deerflow.skills.loader import load_skills from deerflow.skills.loader import load_skills
skills = load_skills(enabled_only=False) skills = load_skills(self._app_config, enabled_only=False)
skill = next((s for s in skills if s.name == name), None) skill = next((s for s in skills if s.name == name), None)
if skill is None: if skill is None:
raise ValueError(f"Skill '{name}' not found") raise ValueError(f"Skill '{name}' not found")
@@ -910,21 +934,25 @@ class DeerFlowClient:
if config_path is None: if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.") raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
extensions_config = get_extensions_config() # Do not mutate self._app_config (frozen value). Compose the new
extensions_config.skills[name] = SkillStateConfig(enabled=enabled) # skills state in a fresh dict, write it to disk, and let _reload_config()
# below rebuild AppConfig from the updated file.
ext = self._app_config.extensions
new_skills = {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()}
new_skills[name] = {"enabled": enabled}
config_data = { config_data = {
"mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()}, "mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()}, "skills": new_skills,
} }
self._atomic_write_json(config_path, config_data) self._atomic_write_json(config_path, config_data)
self._agent = None self._agent = None
self._agent_config_key = None self._agent_config_key = None
reload_extensions_config() self._reload_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None) updated = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
if updated is None: if updated is None:
raise RuntimeError(f"Skill '{name}' disappeared after update") raise RuntimeError(f"Skill '{name}' disappeared after update")
return { return {
@@ -962,25 +990,25 @@ class DeerFlowClient:
""" """
from deerflow.agents.memory.updater import reload_memory_data from deerflow.agents.memory.updater import reload_memory_data
return reload_memory_data(user_id=get_effective_user_id()) return reload_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def clear_memory(self) -> dict: def clear_memory(self) -> dict:
"""Clear all persisted memory data.""" """Clear all persisted memory data."""
from deerflow.agents.memory.updater import clear_memory_data from deerflow.agents.memory.updater import clear_memory_data
return clear_memory_data(user_id=get_effective_user_id()) return clear_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict: def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
"""Create a single fact manually.""" """Create a single fact manually."""
from deerflow.agents.memory.updater import create_memory_fact from deerflow.agents.memory.updater import create_memory_fact
return create_memory_fact(content=content, category=category, confidence=confidence) return create_memory_fact(self._app_config.memory, content=content, category=category, confidence=confidence)
def delete_memory_fact(self, fact_id: str) -> dict: def delete_memory_fact(self, fact_id: str) -> dict:
"""Delete a single fact from memory by fact id.""" """Delete a single fact from memory by fact id."""
from deerflow.agents.memory.updater import delete_memory_fact from deerflow.agents.memory.updater import delete_memory_fact
return delete_memory_fact(fact_id) return delete_memory_fact(self._app_config.memory, fact_id)
def update_memory_fact( def update_memory_fact(
self, self,
@@ -993,6 +1021,7 @@ class DeerFlowClient:
from deerflow.agents.memory.updater import update_memory_fact from deerflow.agents.memory.updater import update_memory_fact
return update_memory_fact( return update_memory_fact(
self._app_config.memory,
fact_id=fact_id, fact_id=fact_id,
content=content, content=content,
category=category, category=category,
@@ -1005,9 +1034,7 @@ class DeerFlowClient:
Returns: Returns:
Memory config dict. Memory config dict.
""" """
from deerflow.config.memory_config import get_memory_config config = self._app_config.memory
config = get_memory_config()
return { return {
"enabled": config.enabled, "enabled": config.enabled,
"storage_path": config.storage_path, "storage_path": config.storage_path,
@@ -25,7 +25,7 @@ except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment] fcntl = None # type: ignore[assignment]
import msvcrt import msvcrt
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox import Sandbox
@@ -90,7 +90,8 @@ class AioSandboxProvider(SandboxProvider):
API_KEY: $MY_API_KEY API_KEY: $MY_API_KEY
""" """
def __init__(self): def __init__(self, app_config: "AppConfig"):
self._app_config = app_config
self._lock = threading.Lock() self._lock = threading.Lock()
self._sandboxes: dict[str, AioSandbox] = {} # sandbox_id -> AioSandbox instance self._sandboxes: dict[str, AioSandbox] = {} # sandbox_id -> AioSandbox instance
self._sandbox_infos: dict[str, SandboxInfo] = {} # sandbox_id -> SandboxInfo (for destroy) self._sandbox_infos: dict[str, SandboxInfo] = {} # sandbox_id -> SandboxInfo (for destroy)
@@ -159,8 +160,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict: def _load_config(self) -> dict:
"""Load sandbox configuration from app config.""" """Load sandbox configuration from app config."""
config = get_app_config() sandbox_config = self._app_config.sandbox
sandbox_config = config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None) idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None) replicas = getattr(sandbox_config, "replicas", None)
@@ -283,17 +283,15 @@ class AioSandboxProvider(SandboxProvider):
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True), (paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
] ]
@staticmethod def _get_skills_mount(self) -> tuple[str, str, bool] | None:
def _get_skills_mount() -> tuple[str, str, bool] | None:
"""Get the skills directory mount configuration. """Get the skills directory mount configuration.
Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD) Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
so the host Docker daemon can resolve the path. so the host Docker daemon can resolve the path.
""" """
try: try:
config = get_app_config() skills_path = self._app_config.skills.get_skills_path()
skills_path = config.skills.get_skills_path() container_path = self._app_config.skills.container_path
container_path = config.skills.container_path
if skills_path.exists(): if skills_path.exists():
# When running inside Docker with DooD, use host-side skills path. # When running inside Docker with DooD, use host-side skills path.
@@ -5,9 +5,9 @@ Web Search Tool - Search the web using DuckDuckGo (no API key required).
import json import json
import logging import logging
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config from deerflow.config.deer_flow_context import resolve_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -55,6 +55,7 @@ def _search_text(
@tool("web_search", parse_docstring=True) @tool("web_search", parse_docstring=True)
def web_search_tool( def web_search_tool(
query: str, query: str,
runtime: ToolRuntime,
max_results: int = 5, max_results: int = 5,
) -> str: ) -> str:
"""Search the web for information. Use this tool to find current information, news, articles, and facts from the internet. """Search the web for information. Use this tool to find current information, news, articles, and facts from the internet.
@@ -63,11 +64,11 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results. query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5. max_results: Maximum number of results to return. Default is 5.
""" """
config = get_app_config().get_tool_config("web_search") tool_config = resolve_context(runtime).app_config.get_tool_config("web_search")
# Override max_results from config if set # Override max_results from config if set
if config is not None and "max_results" in config.model_extra: if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = config.model_extra.get("max_results", max_results) max_results = tool_config.model_extra.get("max_results", max_results)
results = _search_text( results = _search_text(
query=query, query=query,
@@ -1,37 +1,39 @@
import json import json
from exa_py import Exa from exa_py import Exa
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_exa_client(tool_name: str = "web_search") -> Exa: def _get_exa_client(app_config: AppConfig, tool_name: str = "web_search") -> Exa:
config = get_app_config().get_tool_config(tool_name) tool_config = app_config.get_tool_config(tool_name)
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = config.model_extra.get("api_key") api_key = tool_config.model_extra.get("api_key")
return Exa(api_key=api_key) return Exa(api_key=api_key)
@tool("web_search", parse_docstring=True) @tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str: def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web. """Search the web.
Args: Args:
query: The query to search for. query: The query to search for.
""" """
try: try:
config = get_app_config().get_tool_config("web_search") app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5 max_results = 5
search_type = "auto" search_type = "auto"
contents_max_characters = 1000 contents_max_characters = 1000
if config is not None: if tool_config is not None:
max_results = config.model_extra.get("max_results", max_results) max_results = tool_config.model_extra.get("max_results", max_results)
search_type = config.model_extra.get("search_type", search_type) search_type = tool_config.model_extra.get("search_type", search_type)
contents_max_characters = config.model_extra.get("contents_max_characters", contents_max_characters) contents_max_characters = tool_config.model_extra.get("contents_max_characters", contents_max_characters)
client = _get_exa_client() client = _get_exa_client(app_config)
res = client.search( res = client.search(
query, query,
type=search_type, type=search_type,
@@ -54,7 +56,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True) @tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str: def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL. """Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools. Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls. This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -65,7 +67,7 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of. url: The URL to fetch the contents of.
""" """
try: try:
client = _get_exa_client("web_fetch") client = _get_exa_client(resolve_context(runtime).app_config, "web_fetch")
res = client.get_contents([url], text={"max_characters": 4096}) res = client.get_contents([url], text={"max_characters": 4096})
if res.results: if res.results:
@@ -1,33 +1,35 @@
import json import json
from firecrawl import FirecrawlApp from firecrawl import FirecrawlApp
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp: def _get_firecrawl_client(app_config: AppConfig, tool_name: str = "web_search") -> FirecrawlApp:
config = get_app_config().get_tool_config(tool_name) tool_config = app_config.get_tool_config(tool_name)
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = config.model_extra.get("api_key") api_key = tool_config.model_extra.get("api_key")
return FirecrawlApp(api_key=api_key) # type: ignore[arg-type] return FirecrawlApp(api_key=api_key) # type: ignore[arg-type]
@tool("web_search", parse_docstring=True) @tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str: def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web. """Search the web.
Args: Args:
query: The query to search for. query: The query to search for.
""" """
try: try:
config = get_app_config().get_tool_config("web_search") app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5 max_results = 5
if config is not None: if tool_config is not None:
max_results = config.model_extra.get("max_results", max_results) max_results = tool_config.model_extra.get("max_results", max_results)
client = _get_firecrawl_client("web_search") client = _get_firecrawl_client(app_config, "web_search")
result = client.search(query, limit=max_results) result = client.search(query, limit=max_results)
# result.web contains list of SearchResultWeb objects # result.web contains list of SearchResultWeb objects
@@ -47,7 +49,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True) @tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str: def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL. """Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools. Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls. This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -58,7 +60,8 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of. url: The URL to fetch the contents of.
""" """
try: try:
client = _get_firecrawl_client("web_fetch") app_config = resolve_context(runtime).app_config
client = _get_firecrawl_client(app_config, "web_fetch")
result = client.scrape(url, formats=["markdown"]) result = client.scrape(url, formats=["markdown"])
markdown_content = result.markdown or "" markdown_content = result.markdown or ""
@@ -5,9 +5,9 @@ Image Search Tool - Search images using DuckDuckGo for reference in image genera
import json import json
import logging import logging
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config from deerflow.config.deer_flow_context import resolve_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -77,6 +77,7 @@ def _search_images(
@tool("image_search", parse_docstring=True) @tool("image_search", parse_docstring=True)
def image_search_tool( def image_search_tool(
query: str, query: str,
runtime: ToolRuntime,
max_results: int = 5, max_results: int = 5,
size: str | None = None, size: str | None = None,
type_image: str | None = None, type_image: str | None = None,
@@ -99,11 +100,11 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references. type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs. layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
""" """
config = get_app_config().get_tool_config("image_search") tool_config = resolve_context(runtime).app_config.get_tool_config("image_search")
# Override max_results from config if set # Override max_results from config if set
if config is not None and "max_results" in config.model_extra: if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = config.model_extra.get("max_results", max_results) max_results = tool_config.model_extra.get("max_results", max_results)
results = _search_images( results = _search_images(
query=query, query=query,
@@ -1,6 +1,7 @@
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
from deerflow.utils.readability import ReadabilityExtractor from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient from .infoquest_client import InfoQuestClient
@@ -8,13 +9,13 @@ from .infoquest_client import InfoQuestClient
readability_extractor = ReadabilityExtractor() readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient: def _get_infoquest_client(app_config: AppConfig) -> InfoQuestClient:
search_config = get_app_config().get_tool_config("web_search") search_config = app_config.get_tool_config("web_search")
search_time_range = -1 search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra: if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range") search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = get_app_config().get_tool_config("web_fetch") fetch_config = app_config.get_tool_config("web_fetch")
fetch_time = -1 fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra: if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time") fetch_time = fetch_config.model_extra.get("fetch_time")
@@ -25,7 +26,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra: if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout") navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = get_app_config().get_tool_config("image_search") image_search_config = app_config.get_tool_config("image_search")
image_search_time_range = -1 image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra: if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range") image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@@ -44,19 +45,18 @@ def _get_infoquest_client() -> InfoQuestClient:
@tool("web_search", parse_docstring=True) @tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str: def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web. """Search the web.
Args: Args:
query: The query to search for. query: The query to search for.
""" """
client = _get_infoquest_client(resolve_context(runtime).app_config)
client = _get_infoquest_client()
return client.web_search(query) return client.web_search(query)
@tool("web_fetch", parse_docstring=True) @tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str: def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL. """Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools. Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls. This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -66,7 +66,7 @@ def web_fetch_tool(url: str) -> str:
Args: Args:
url: The URL to fetch the contents of. url: The URL to fetch the contents of.
""" """
client = _get_infoquest_client() client = _get_infoquest_client(resolve_context(runtime).app_config)
result = client.fetch(url) result = client.fetch(url)
if result.startswith("Error: "): if result.startswith("Error: "):
return result return result
@@ -75,7 +75,7 @@ def web_fetch_tool(url: str) -> str:
@tool("image_search", parse_docstring=True) @tool("image_search", parse_docstring=True)
def image_search_tool(query: str) -> str: def image_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search for images online. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy. """Search for images online. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.
**When to use:** **When to use:**
@@ -89,5 +89,5 @@ def image_search_tool(query: str) -> str:
Args: Args:
query: The query to search for images. query: The query to search for images.
""" """
client = _get_infoquest_client() client = _get_infoquest_client(resolve_context(runtime).app_config)
return client.image_search(query) return client.image_search(query)
@@ -1,16 +1,16 @@
import asyncio import asyncio
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from deerflow.community.jina_ai.jina_client import JinaClient from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config import get_app_config from deerflow.config.deer_flow_context import resolve_context
from deerflow.utils.readability import ReadabilityExtractor from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor() readability_extractor = ReadabilityExtractor()
@tool("web_fetch", parse_docstring=True) @tool("web_fetch", parse_docstring=True)
async def web_fetch_tool(url: str) -> str: async def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL. """Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools. Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls. This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -22,9 +22,9 @@ async def web_fetch_tool(url: str) -> str:
""" """
jina_client = JinaClient() jina_client = JinaClient()
timeout = 10 timeout = 10
config = get_app_config().get_tool_config("web_fetch") tool_config = resolve_context(runtime).app_config.get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra: if tool_config is not None and "timeout" in tool_config.model_extra:
timeout = config.model_extra.get("timeout") timeout = tool_config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout) html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"): if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content return html_content
@@ -1,32 +1,34 @@
import json import json
from langchain.tools import tool from langchain.tools import ToolRuntime, tool
from tavily import TavilyClient from tavily import TavilyClient
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_tavily_client() -> TavilyClient: def _get_tavily_client(app_config: AppConfig) -> TavilyClient:
config = get_app_config().get_tool_config("web_search") tool_config = app_config.get_tool_config("web_search")
api_key = None api_key = None
if config is not None and "api_key" in config.model_extra: if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = config.model_extra.get("api_key") api_key = tool_config.model_extra.get("api_key")
return TavilyClient(api_key=api_key) return TavilyClient(api_key=api_key)
@tool("web_search", parse_docstring=True) @tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str: def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web. """Search the web.
Args: Args:
query: The query to search for. query: The query to search for.
""" """
config = get_app_config().get_tool_config("web_search") app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5 max_results = 5
if config is not None and "max_results" in config.model_extra: if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = config.model_extra.get("max_results") max_results = tool_config.model_extra.get("max_results")
client = _get_tavily_client() client = _get_tavily_client(app_config)
res = client.search(query, max_results=max_results) res = client.search(query, max_results=max_results)
normalized_results = [ normalized_results = [
{ {
@@ -41,7 +43,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True) @tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str: def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL. """Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools. Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls. This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -51,7 +53,8 @@ def web_fetch_tool(url: str) -> str:
Args: Args:
url: The URL to fetch the contents of. url: The URL to fetch the contents of.
""" """
client = _get_tavily_client() app_config = resolve_context(runtime).app_config
client = _get_tavily_client(app_config)
res = client.extract([url]) res = client.extract([url])
if "failed_results" in res and len(res["failed_results"]) > 0: if "failed_results" in res and len(res["failed_results"]) > 0:
return f"Error: {res['failed_results'][0]['error']}" return f"Error: {res['failed_results'][0]['error']}"
@@ -1,6 +1,6 @@
from .app_config import get_app_config from .app_config import AppConfig
from .extensions_config import ExtensionsConfig, get_extensions_config from .extensions_config import ExtensionsConfig
from .memory_config import MemoryConfig, get_memory_config from .memory_config import MemoryConfig
from .paths import Paths, get_paths from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig from .skills_config import SkillsConfig
@@ -13,18 +13,16 @@ from .tracing_config import (
) )
__all__ = [ __all__ = [
"get_app_config", "AppConfig",
"SkillEvolutionConfig",
"Paths",
"get_paths",
"SkillsConfig",
"ExtensionsConfig", "ExtensionsConfig",
"get_extensions_config",
"MemoryConfig", "MemoryConfig",
"get_memory_config", "Paths",
"get_tracing_config", "SkillEvolutionConfig",
"get_explicitly_enabled_tracing_providers", "SkillsConfig",
"get_enabled_tracing_providers", "get_enabled_tracing_providers",
"get_explicitly_enabled_tracing_providers",
"get_paths",
"get_tracing_config",
"is_tracing_enabled", "is_tracing_enabled",
"validate_enabled_tracing_providers", "validate_enabled_tracing_providers",
] ]
@@ -1,16 +1,13 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml.""" """ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
import logging from pydantic import BaseModel, ConfigDict, Field
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ACPAgentConfig(BaseModel): class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent.""" """Configuration for a single ACP-compatible agent."""
model_config = ConfigDict(frozen=True)
command: str = Field(description="Command to launch the ACP agent subprocess") command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments") args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.") env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
@@ -24,28 +21,3 @@ class ACPAgentConfig(BaseModel):
"are denied — the agent must be configured to operate without requesting permissions." "are denied — the agent must be configured to operate without requesting permissions."
), ),
) )
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))
@@ -1,32 +1,14 @@
"""Configuration for the custom agents management API.""" """Configuration for the custom agents management API."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class AgentsApiConfig(BaseModel): class AgentsApiConfig(BaseModel):
"""Configuration for custom-agent and user-profile management routes.""" """Configuration for custom-agent and user-profile management routes."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."), description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."),
) )
_agents_api_config: AgentsApiConfig = AgentsApiConfig()
def get_agents_api_config() -> AgentsApiConfig:
"""Get the current agents API configuration."""
return _agents_api_config
def set_agents_api_config(config: AgentsApiConfig) -> None:
"""Set the agents API configuration."""
global _agents_api_config
_agents_api_config = config
def load_agents_api_config_from_dict(config_dict: dict) -> None:
"""Load agents API configuration from a dictionary."""
global _agents_api_config
_agents_api_config = AgentsApiConfig(**config_dict)
@@ -5,7 +5,7 @@ import re
from typing import Any from typing import Any
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from deerflow.config.paths import get_paths from deerflow.config.paths import get_paths
@@ -29,6 +29,8 @@ def validate_agent_name(name: str | None) -> str | None:
class AgentConfig(BaseModel): class AgentConfig(BaseModel):
"""Configuration for a custom agent.""" """Configuration for a custom agent."""
model_config = ConfigDict(frozen=True)
name: str name: str
description: str = "" description: str = ""
model: str | None = None model: str | None = None
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging import logging
import os import os
from contextvars import ContextVar
from pathlib import Path from pathlib import Path
from typing import Any, Self from typing import Any, Self
@@ -8,25 +9,25 @@ import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict from deerflow.config.agents_api_config import AgentsApiConfig
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.database_config import DatabaseConfig from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict from deerflow.config.stream_bridge_config import StreamBridgeConfig
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict from deerflow.config.subagents_config import SubagentsAppConfig
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict from deerflow.config.summarization_config import SummarizationConfig
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict from deerflow.config.title_config import TitleConfig
from deerflow.config.token_usage_config import TokenUsageConfig from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict from deerflow.config.tool_search_config import ToolSearchConfig
load_dotenv() load_dotenv()
@@ -73,11 +74,12 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
model_config = ConfigDict(extra="allow", frozen=True)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP agent configurations keyed by agent name")
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path: def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -126,49 +128,6 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data) config_data = cls.resolve_env_variables(config_data)
cls._apply_database_defaults(config_data) cls._apply_database_defaults(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file) # Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file() extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump() config_data["extensions"] = extensions_config.model_dump()
@@ -291,130 +250,8 @@ class AppConfig(BaseModel):
""" """
return next((group for group in self.tool_groups if group.name == name), None) return next((group for group in self.tool_groups if group.name == name), None)
# AppConfig is a pure value object: construct with ``from_file()``, pass around.
_app_config: AppConfig | None = None # Composition roots that hold the resolved instance:
_app_config_path: Path | None = None # - Gateway: ``app.state.config`` via ``Depends(get_config)``
_app_config_mtime: float | None = None # - Client: ``DeerFlowClient._app_config``
_app_config_is_custom = False # - Agent run: ``Runtime[DeerFlowContext].context.app_config``
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"] CheckpointerType = Literal["memory", "sqlite", "postgres"]
@@ -10,6 +10,8 @@ CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel): class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer.""" """Configuration for LangGraph state persistence checkpointer."""
model_config = ConfigDict(frozen=True)
type: CheckpointerType = Field( type: CheckpointerType = Field(
description="Checkpointer backend type. " description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). " "'memory' is in-process only (lost on restart). "
@@ -23,24 +25,3 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. " "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.", "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
) )
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -34,10 +34,11 @@ from __future__ import annotations
import os import os
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class DatabaseConfig(BaseModel): class DatabaseConfig(BaseModel):
model_config = ConfigDict(frozen=True)
backend: Literal["memory", "sqlite", "postgres"] = Field( backend: Literal["memory", "sqlite", "postgres"] = Field(
default="memory", default="memory",
description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."), description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
@@ -0,0 +1,55 @@
"""Per-invocation context for DeerFlow agent execution.
Injected via LangGraph Runtime. Middleware and tools access this
via Runtime[DeerFlowContext] parameters, through resolve_context().
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DeerFlowContext:
"""Typed, immutable, per-invocation context injected via LangGraph Runtime.
Fields are all known at run start and never change during execution.
Mutable runtime state (e.g. sandbox_id) flows through ThreadState, not here.
"""
app_config: AppConfig
thread_id: str
agent_name: str | None = None
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Return the typed DeerFlowContext that the runtime carries.
Gateway mode (``DeerFlowClient``, ``run_agent``) always attaches a typed
``DeerFlowContext`` via ``agent.astream(context=...)``; the LangGraph
Server path uses ``langgraph.json`` registration where the top-level
``make_lead_agent`` loads ``AppConfig`` from disk itself, so we still
arrive here with a typed context.
Only the dict/None shapes that legacy tests used to exercise would fall
through this function; we now reject them loudly instead of papering
over the missing context with an ambient ``AppConfig`` lookup.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
raise RuntimeError(
"resolve_context: runtime.context is not a DeerFlowContext "
"(got type %s). Every entry point must attach one at invoke time — "
"Gateway/Client via agent.astream(context=DeerFlowContext(...)), "
"LangGraph Server via the make_lead_agent boundary that loads "
"AppConfig.from_file()." % type(ctx).__name__
)
@@ -11,6 +11,8 @@ from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel): class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports).""" """OAuth configuration for an MCP server (HTTP/SSE transports)."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled") enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL") token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field( grant_type: Literal["client_credentials", "refresh_token"] = Field(
@@ -28,12 +30,13 @@ class McpOAuthConfig(BaseModel):
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response") default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry") refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint") extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel): class McpServerConfig(BaseModel):
"""Configuration for a single MCP server.""" """Configuration for a single MCP server."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether this MCP server is enabled") enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'") type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)") command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
@@ -43,12 +46,13 @@ class McpServerConfig(BaseModel):
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)") headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)") oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides") description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel): class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state.""" """Configuration for a single skill's state."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=True, description="Whether this skill is enabled") enabled: bool = Field(default=True, description="Whether this skill is enabled")
@@ -64,7 +68,7 @@ class ExtensionsConfig(BaseModel):
default_factory=dict, default_factory=dict,
description="Map of skill name to state configuration", description="Map of skill name to state configuration",
) )
model_config = ConfigDict(extra="allow", populate_by_name=True) model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True)
@classmethod @classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None: def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
@@ -195,62 +199,3 @@ class ExtensionsConfig(BaseModel):
# Default to enable for public & custom skill # Default to enable for public & custom skill
return skill_category in ("public", "custom") return skill_category in ("public", "custom")
return skill_config.enabled return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config
@@ -1,11 +1,13 @@
"""Configuration for pre-tool-call authorization.""" """Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class GuardrailProviderConfig(BaseModel): class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider.""" """Configuration for a guardrail provider."""
model_config = ConfigDict(frozen=True)
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')") use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs") config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
@@ -18,31 +20,9 @@ class GuardrailsConfig(BaseModel):
agent's passport reference, and returns an allow/deny decision. agent's passport reference, and returns an allow/deny decision.
""" """
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable guardrail middleware") enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors") fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID") passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration") provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None
@@ -1,11 +1,13 @@
"""Configuration for memory mechanism.""" """Configuration for memory mechanism."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class MemoryConfig(BaseModel): class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism.""" """Configuration for global memory mechanism."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=True, default=True,
description="Whether to enable memory mechanism", description="Whether to enable memory mechanism",
@@ -60,24 +62,3 @@ class MemoryConfig(BaseModel):
le=8000, le=8000,
description="Maximum tokens to use for memory injection", description="Maximum tokens to use for memory injection",
) )
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)
@@ -12,7 +12,7 @@ class ModelConfig(BaseModel):
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)", description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
) )
model: str = Field(..., description="Model name") model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
use_responses_api: bool | None = Field( use_responses_api: bool | None = Field(
default=None, default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API", description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
@@ -15,10 +15,11 @@ from __future__ import annotations
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class RunEventsConfig(BaseModel): class RunEventsConfig(BaseModel):
model_config = ConfigDict(frozen=True)
backend: Literal["memory", "db", "jsonl"] = Field( backend: Literal["memory", "db", "jsonl"] = Field(
default="memory", default="memory",
description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.", description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
@@ -4,6 +4,8 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel): class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount.""" """Configuration for a volume mount."""
model_config = ConfigDict(frozen=True)
host_path: str = Field(..., description="Path on the host machine") host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container") container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only") read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -80,4 +82,4 @@ class SandboxConfig(BaseModel):
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.", description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
) )
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,9 +1,11 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class SkillEvolutionConfig(BaseModel): class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution.""" """Configuration for agent-managed skill evolution."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Whether the agent can create and modify skills under skills/custom.", description="Whether the agent can create and modify skills under skills/custom.",
@@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
def _default_repo_root() -> Path: def _default_repo_root() -> Path:
@@ -11,6 +11,8 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel): class SkillsConfig(BaseModel):
"""Configuration for skills system""" """Configuration for skills system"""
model_config = ConfigDict(frozen=True)
path: str | None = Field( path: str | None = Field(
default=None, default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory", description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
StreamBridgeType = Literal["memory", "redis"] StreamBridgeType = Literal["memory", "redis"]
@@ -10,6 +10,8 @@ StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel): class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints.""" """Configuration for the stream bridge that connects agent workers to SSE endpoints."""
model_config = ConfigDict(frozen=True)
type: StreamBridgeType = Field( type: StreamBridgeType = Field(
default="memory", default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).", description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
@@ -22,25 +24,3 @@ class StreamBridgeConfig(BaseModel):
default=256, default=256,
description="Maximum number of events buffered per run in the memory bridge.", description="Maximum number of events buffered per run in the memory bridge.",
) )
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -1,15 +1,13 @@
"""Configuration for the subagent system loaded from config.yaml.""" """Configuration for the subagent system loaded from config.yaml."""
import logging from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class SubagentOverrideConfig(BaseModel): class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides.""" """Per-agent configuration overrides."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int | None = Field( timeout_seconds: int | None = Field(
default=None, default=None,
ge=1, ge=1,
@@ -71,6 +69,8 @@ class CustomSubagentConfig(BaseModel):
class SubagentsAppConfig(BaseModel): class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system.""" """Configuration for the subagent system."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int = Field( timeout_seconds: int = Field(
default=900, default=900,
ge=1, ge=1,
@@ -140,48 +140,3 @@ class SubagentsAppConfig(BaseModel):
if override is not None and override.skills is not None: if override is not None and override.skills is not None:
return override.skills return override.skills
return None return None
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if override.model is not None:
parts.append(f"model={override.model}")
if override.skills is not None:
parts.append(f"skills={override.skills}")
if parts:
overrides_summary[name] = ", ".join(parts)
custom_agents_names = list(_subagents_config.custom_agents.keys())
if overrides_summary or custom_agents_names:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s, custom_agents=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -2,7 +2,7 @@
from typing import Literal from typing import Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
ContextSizeType = Literal["fraction", "tokens", "messages"] ContextSizeType = Literal["fraction", "tokens", "messages"]
@@ -10,6 +10,8 @@ ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel): class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters.""" """Context size specification for trigger or keep parameters."""
model_config = ConfigDict(frozen=True)
type: ContextSizeType = Field(description="Type of context size specification") type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification") value: int | float = Field(description="Value for the context size specification")
@@ -21,6 +23,8 @@ class ContextSize(BaseModel):
class SummarizationConfig(BaseModel): class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization.""" """Configuration for automatic conversation summarization."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Whether to enable automatic conversation summarization", description="Whether to enable automatic conversation summarization",
@@ -70,24 +74,3 @@ class SummarizationConfig(BaseModel):
default_factory=lambda: ["read_file", "read", "view", "cat"], default_factory=lambda: ["read_file", "read", "view", "cat"],
description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.", description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
) )
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)
@@ -1,11 +1,13 @@
"""Configuration for automatic thread title generation.""" """Configuration for automatic thread title generation."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class TitleConfig(BaseModel): class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation.""" """Configuration for automatic thread title generation."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=True, default=True,
description="Whether to enable automatic title generation", description="Whether to enable automatic title generation",
@@ -30,24 +32,3 @@ class TitleConfig(BaseModel):
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."), default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation", description="Prompt template for title generation",
) )
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)
@@ -1,7 +1,9 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class TokenUsageConfig(BaseModel): class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking.""" """Configuration for token usage tracking."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable token usage tracking middleware") enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
@@ -5,7 +5,7 @@ class ToolGroupConfig(BaseModel):
"""Config section for a tool group""" """Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group") name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
@@ -17,4 +17,4 @@ class ToolConfig(BaseModel):
..., ...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)", description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
) )
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,6 +1,6 @@
"""Configuration for deferred tool loading via tool_search.""" """Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
class ToolSearchConfig(BaseModel): class ToolSearchConfig(BaseModel):
@@ -11,25 +11,9 @@ class ToolSearchConfig(BaseModel):
via the tool_search tool at runtime. via the tool_search tool at runtime.
""" """
model_config = ConfigDict(frozen=True)
enabled: bool = Field( enabled: bool = Field(
default=False, default=False,
description="Defer tools and enable tool_search", description="Defer tools and enable tool_search",
) )
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config
@@ -1,7 +1,7 @@
import os import os
import threading import threading
from pydantic import BaseModel, Field from pydantic import BaseModel, ConfigDict, Field
_config_lock = threading.Lock() _config_lock = threading.Lock()
@@ -9,6 +9,8 @@ _config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel): class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing.""" """Configuration for LangSmith tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...) enabled: bool = Field(...)
api_key: str | None = Field(...) api_key: str | None = Field(...)
project: str = Field(...) project: str = Field(...)
@@ -26,6 +28,8 @@ class LangSmithTracingConfig(BaseModel):
class LangfuseTracingConfig(BaseModel): class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing.""" """Configuration for Langfuse tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...) enabled: bool = Field(...)
public_key: str | None = Field(...) public_key: str | None = Field(...)
secret_key: str | None = Field(...) secret_key: str | None = Field(...)
@@ -50,6 +54,8 @@ class LangfuseTracingConfig(BaseModel):
class TracingConfig(BaseModel): class TracingConfig(BaseModel):
"""Tracing configuration for supported providers.""" """Tracing configuration for supported providers."""
model_config = ConfigDict(frozen=True)
langsmith: LangSmithTracingConfig = Field(...) langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...) langfuse: LangfuseTracingConfig = Field(...)
@@ -2,7 +2,7 @@ import logging
from langchain.chat_models import BaseChatModel from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks from deerflow.tracing import build_tracing_callbacks
@@ -46,16 +46,23 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel: def create_chat_model(
name: str | None = None,
thinking_enabled: bool = False,
*,
app_config: "AppConfig",
**kwargs,
) -> BaseChatModel:
"""Create a chat model instance from the config. """Create a chat model instance from the config.
Args: Args:
name: The name of the model to create. If None, the first model in the config will be used. name: The name of the model to create. If None, the first model in the config will be used.
app_config: Application config required.
Returns: Returns:
A chat model instance. A chat model instance.
""" """
config = get_app_config() config = app_config
if name is None: if name is None:
name = config.models[0].name name = config.models[0].name
model_config = config.get_model_config(name) model_config = config.get_model_config(name)
@@ -13,7 +13,9 @@ from deerflow.persistence.base import Base
class FeedbackRow(Base): class FeedbackRow(Base):
__tablename__ = "feedback" __tablename__ = "feedback"
__table_args__ = (UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),) __table_args__ = (
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
)
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True) feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
@@ -18,9 +18,7 @@ from deerflow.persistence.base import Base
# Import all models so metadata is populated. # Import all models so metadata is populated.
try: try:
import deerflow.persistence.models as models # register ORM models with Base.metadata import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
_ = models
except ImportError: except ImportError:
# Models not available — migration will work with existing metadata only. # Models not available — migration will work with existing metadata only.
logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables") logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")
@@ -24,7 +24,7 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.runtime.checkpointer.provider import ( from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED, POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL, POSTGRES_INSTALL,
@@ -123,11 +123,11 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]: async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime. """Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state:: Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer() as checkpointer: async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
@@ -138,16 +138,14 @@ async def make_checkpointer() -> AsyncIterator[Checkpointer]:
3. Default InMemorySaver 3. Default InMemorySaver
""" """
config = get_app_config()
# Legacy: standalone checkpointer config takes precedence # Legacy: standalone checkpointer config takes precedence
if config.checkpointer is not None: if app_config.checkpointer is not None:
async with _async_checkpointer(config.checkpointer) as saver: async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver yield saver
return return
# Unified database config # Unified database config
db_config = getattr(config, "database", None) db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory": if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver: async with _async_checkpointer_from_database(db_config) as saver:
yield saver yield saver
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -100,10 +100,13 @@ _checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive _checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer() -> Checkpointer: def get_checkpointer(app_config: AppConfig) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call. """Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
absent from config.yaml. Any other failure (missing config, invalid
backend, connection error) propagates silent degradation to in-memory
would drop persistent-run state on process restart.
Raises: Raises:
ImportError: If the required package for the configured backend is not installed. ImportError: If the required package for the configured backend is not installed.
@@ -114,25 +117,7 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None: if _checkpointer is not None:
return _checkpointer return _checkpointer
# Ensure app config is loaded before checking checkpointer config config = app_config.checkpointer
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
@@ -168,25 +153,23 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager @contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]: def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit. """Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance Unlike :func:`get_checkpointer`, this does **not** cache the instance
each ``with`` block creates and destroys its own connection. Use it in each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup:: CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context() as cp: with checkpointer_context(app_config) as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}}) graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
""" """
if app_config.checkpointer is None:
config = get_app_config()
if config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver() yield InMemorySaver()
return return
with _sync_checkpointer_cm(config.checkpointer) as saver: with _sync_checkpointer_cm(app_config.checkpointer) as saver:
yield saver yield saver
@@ -6,10 +6,7 @@ handles token usage accumulation.
Key design decisions: Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end - on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and - on_chat_model_start captures structured prompts as llm_request (OpenAI format)
extracts the first human message for run.input, because it is more reliable than
on_chain_start (fires on every node) messages here are fully structured.
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
- on_llm_end emits llm_response in OpenAI Chat Completions format - on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion - Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name}) - Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
@@ -21,12 +18,10 @@ import asyncio
import logging import logging
import time import time
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any
from uuid import UUID from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.types import Command
if TYPE_CHECKING: if TYPE_CHECKING:
from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.events.store.base import RunEventStore
@@ -77,39 +72,34 @@ class RunJournal(BaseCallbackHandler):
# LLM request/response tracking # LLM request/response tracking
self._llm_call_index = 0 self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
# Tool call ID cache
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
# -- Lifecycle callbacks -- # -- Lifecycle callbacks --
def on_chain_start( def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
self, if kwargs.get("parent_run_id") is not None:
serialized: dict[str, Any], return
inputs: dict[str, Any], self._put(
*, event_type="run_start",
run_id: UUID, category="lifecycle",
parent_run_id: UUID | None = None, metadata={"input_preview": str(inputs)[:500]},
tags: list[str] | None = None, )
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
caller = self._identify_caller(tags)
if parent_run_id is None:
# Root graph invocation — emit a single trace event for the run start.
chain_name = (serialized or {}).get("name", "unknown")
self._put(
event_type="run.start",
category="trace",
content={"chain": chain_name},
metadata={"caller": caller, **(metadata or {})},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"}) if kwargs.get("parent_run_id") is not None:
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync() self._flush_sync()
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put( self._put(
event_type="run.error", event_type="run_error",
category="error", category="lifecycle",
content=str(error), content=str(error),
metadata={"error_type": type(error).__name__}, metadata={"error_type": type(error).__name__},
) )
@@ -117,132 +107,266 @@ class RunJournal(BaseCallbackHandler):
# -- LLM callbacks -- # -- LLM callbacks --
def on_chat_model_start( def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
self, """Capture structured prompt messages for llm_request event."""
serialized: dict, from deerflow.runtime.converters import langchain_messages_to_openai
messages: list[list[BaseMessage]],
*,
run_id: UUID,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Capture structured prompt messages for llm_request event.
This is also the canonical place to extract the first human message:
messages are fully structured here, it fires only on real LLM calls,
and the content is never compressed by checkpoint trimming.
"""
rid = str(run_id) rid = str(run_id)
self._llm_start_times[rid] = time.monotonic() self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1 self._llm_call_index += 1
# Mark this run_id as seen so on_llm_end knows not to increment again.
self._cached_prompts[rid] = []
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}") model_name = serialized.get("name", "")
self._cached_models[rid] = model_name
# Capture the first human message sent to any LLM in this run. # Convert the first message list (LangChain passes list-of-lists)
if not self._first_human_msg and not messages: prompt_msgs = messages[0] if messages else []
for batch in messages.reversed(): openai_msgs = langchain_messages_to_openai(prompt_msgs)
for m in batch.reversed(): self._cached_prompts[rid] = openai_msgs
if isinstance(m, HumanMessage) and m.name != "summary":
caller = self._identify_caller(tags)
self.set_first_human_message(m.text)
self._put(
event_type="llm.human.input",
category="message",
content=m.model_dump(),
metadata={"caller": caller},
)
break
if self._first_human_msg:
break
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None: caller = self._identify_caller(kwargs)
self._put(
event_type="llm_request",
category="trace",
content={"model": model_name, "messages": openai_msgs},
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
)
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
# Fallback: on_chat_model_start is preferred. This just tracks latency. # Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic() self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None: def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
messages: list[AnyMessage] = [] from deerflow.runtime.converters import langchain_to_openai_completion
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
for generation in response.generations:
for gen in generation:
if hasattr(gen, "message"):
messages.append(gen.message)
else:
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
for message in messages: try:
caller = self._identify_caller(tags) message = response.generations[0][0].message
except (IndexError, AttributeError):
logger.debug("on_llm_end: could not extract message from response")
return
# Latency caller = self._identify_caller(kwargs)
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message # Latency
usage = getattr(message, "usage_metadata", None) rid = str(run_id)
usage_dict = dict(usage) if usage else {} start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Resolve call index # Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Trace event: llm_response (OpenAI completion format) # Clean up caches
self._put( self._cached_prompts.pop(rid, None)
event_type="llm.ai.response", self._cached_models.pop(rid, None)
category="message",
content=message.model_dump(),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Token accumulation # Trace event: llm_response (OpenAI completion format)
if self._track_tokens: content = getattr(message, "content", "")
input_tk = usage_dict.get("input_tokens", 0) or 0 self._put(
output_tk = usage_dict.get("output_tokens", 0) or 0 event_type="llm_response",
total_tk = usage_dict.get("total_tokens", 0) or 0 category="trace",
if total_tk == 0: content=langchain_to_openai_completion(message),
total_tk = input_tk + output_tk metadata={
if total_tk > 0: "caller": caller,
self._total_input_tokens += input_tk "usage": usage_dict,
self._total_output_tokens += output_tk "latency_ms": latency_ms,
self._total_tokens += total_tk "llm_call_index": call_index,
self._llm_call_count += 1 },
)
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
if tool_calls:
# ai_tool_call: agent decided to use tools
self._put(
event_type="ai_tool_call",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
# ai_message: final text reply
self._put(
event_type="ai_message",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None) self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm.error", category="trace", content=str(error)) self._put(event_type="llm_error", category="trace", content=str(error))
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs): # -- Tool callbacks --
"""Handle tool start event, cache tool call ID for later correlation"""
tool_call_id = str(run_id)
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs): def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
"""Handle tool end event, append message and clear node data""" tool_call_id = kwargs.get("tool_call_id")
try: if tool_call_id:
if isinstance(output, ToolMessage): self._tool_call_ids[str(run_id)] = tool_call_id
msg = cast(ToolMessage, output) self._put(
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump()) event_type="tool_start",
elif isinstance(output, Command): category="trace",
cmd = cast(Command, output) metadata={
messages = cmd.update.get("messages", []) "tool_name": serialized.get("name", ""),
for message in messages: "tool_call_id": tool_call_id,
if isinstance(message, BaseMessage): "args": str(input_str)[:2000],
self._put(event_type="llm.tool.result", category="message", content=message.model_dump()) },
else: )
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
else: def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}") from langchain_core.messages import ToolMessage
finally: from langgraph.types import Command
logger.info(f"Tool end for node {run_id}")
# Tools that update graph state return a ``Command`` (e.g.
# ``present_files``). LangGraph later unwraps the inner ToolMessage
# into checkpoint state, so to stay checkpoint-aligned we must
# extract it here rather than storing ``str(Command(...))``.
if isinstance(output, Command):
update = getattr(output, "update", None) or {}
inner_msgs = update.get("messages") if isinstance(update, dict) else None
if isinstance(inner_msgs, list):
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
if inner_tool_msg is not None:
output = inner_tool_msg
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": status,
},
)
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
if name == "summarization":
data_dict = data if isinstance(data, dict) else {}
self._put(
event_type="summarization",
category="trace",
content=data_dict.get("summary", ""),
metadata={
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
"replaced_count": data_dict.get("replaced_count", 0),
},
)
self._put(
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
else:
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
self._put(
event_type=name,
category="trace",
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
)
# -- Internal methods -- # -- Internal methods --
@@ -307,9 +431,8 @@ class RunJournal(BaseCallbackHandler):
if exc: if exc:
logger.warning("Journal flush task failed: %s", exc) logger.warning("Journal flush task failed: %s", exc)
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str: def _identify_caller(self, kwargs: dict) -> str:
_tags = tags or kwargs.get("tags", []) for tag in kwargs.get("tags") or []:
for tag in _tags:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"): if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag return tag
# Default to lead_agent: the main agent graph does not inject # Default to lead_agent: the main agent graph does not inject
@@ -54,7 +54,7 @@ class RunManager:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._store = store self._store = store
async def _persist_to_store(self, record: RunRecord) -> None: async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
"""Best-effort persist run record to backing store.""" """Best-effort persist run record to backing store."""
if self._store is None: if self._store is None:
return return
@@ -68,6 +68,7 @@ class RunManager:
metadata=record.metadata or {}, metadata=record.metadata or {},
kwargs=record.kwargs or {}, kwargs=record.kwargs or {},
created_at=record.created_at, created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
) )
except Exception: except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -89,6 +90,7 @@ class RunManager:
metadata: dict | None = None, metadata: dict | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Create a new pending run and register it.""" """Create a new pending run and register it."""
run_id = str(uuid.uuid4()) run_id = str(uuid.uuid4())
@@ -107,7 +109,7 @@ class RunManager:
) )
async with self._lock: async with self._lock:
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record) await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -174,6 +176,7 @@ class RunManager:
metadata: dict | None = None, metadata: dict | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Atomically check for inflight runs and create a new one. """Atomically check for inflight runs and create a new one.
@@ -227,7 +230,7 @@ class RunManager:
) )
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record) await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -29,6 +29,7 @@ class RunStore(abc.ABC):
kwargs: dict[str, Any] | None = None, kwargs: dict[str, Any] | None = None,
error: str | None = None, error: str | None = None,
created_at: str | None = None, created_at: str | None = None,
follow_up_to_run_id: str | None = None,
) -> None: ) -> None:
pass pass
@@ -28,6 +28,7 @@ class MemoryRunStore(RunStore):
kwargs=None, kwargs=None,
error=None, error=None,
created_at=None, created_at=None,
follow_up_to_run_id=None,
): ):
now = datetime.now(UTC).isoformat() now = datetime.now(UTC).isoformat()
self._runs[run_id] = { self._runs[run_id] = {
@@ -40,6 +41,7 @@ class MemoryRunStore(RunStore):
"metadata": metadata or {}, "metadata": metadata or {},
"kwargs": kwargs or {}, "kwargs": kwargs or {},
"error": error, "error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"created_at": created_at or now, "created_at": created_at or now,
"updated_at": now, "updated_at": now,
} }
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge from deerflow.runtime.stream_bridge import StreamBridge
@@ -51,6 +53,8 @@ class RunContext:
event_store: Any | None = field(default=None) event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None) run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None) thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
app_config: AppConfig | None = field(default=None)
async def run_agent( async def run_agent(
@@ -75,6 +79,7 @@ async def run_agent(
event_store = ctx.event_store event_store = ctx.event_store
run_events_config = ctx.run_events_config run_events_config = ctx.run_events_config
thread_store = ctx.thread_store thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id run_id = record.run_id
thread_id = record.thread_id thread_id = record.thread_id
@@ -111,6 +116,22 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
) )
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# 1. Mark running # 1. Mark running
await run_manager.set_status(run_id, RunStatus.running) await run_manager.set_status(run_id, RunStatus.running)
@@ -144,18 +165,21 @@ async def run_agent(
# 3. Build the agent # 3. Build the agent
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id # Construct typed context for the agent run.
# (langgraph-cli does this automatically; we must do it manually) # LangGraph's astream(context=...) injects this into Runtime.context
runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store) # so middleware/tools can access it via resolve_context().
# If the caller already set a ``context`` key (LangGraph >= 0.6.0 if ctx.app_config is None:
# prefers it over ``configurable`` for thread-level data), make raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
# sure ``thread_id`` is available there too. deer_flow_context = DeerFlowContext(
if "context" in config and isinstance(config["context"], dict): app_config=ctx.app_config,
config["context"].setdefault("thread_id", thread_id) thread_id=thread_id,
config["context"].setdefault("run_id", run_id) )
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None:
config.setdefault("callbacks", []).append(journal)
# Inject RunJournal as a LangChain callback handler. # Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle. # on_llm_end captures token usage; on_chain_start/end captures lifecycle.
@@ -207,7 +231,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs: if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks # Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0] single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode): async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
if record.abort_event.is_set(): if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id) logger.info("Run %s abort requested — stopping", run_id)
break break
@@ -218,6 +242,7 @@ async def run_agent(
async for item in agent.astream( async for item in agent.astream(
graph_input, graph_input,
config=runnable_config, config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes, stream_mode=lg_modes,
subgraphs=stream_subgraphs, subgraphs=stream_subgraphs,
): ):
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -86,7 +86,7 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_store() -> AsyncIterator[BaseStore]: async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
"""Async context manager that yields a Store whose backend matches the """Async context manager that yields a Store whose backend matches the
configured checkpointer. configured checkpointer.
@@ -94,20 +94,18 @@ async def make_store() -> AsyncIterator[BaseStore]:
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so :func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
that both singletons always use the same persistence technology:: that both singletons always use the same persistence technology::
async with make_store() as store: async with make_store(app_config) as store:
app.state.store = store app.state.store = store
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case). ``checkpointer`` section is configured (emits a WARNING in that case).
""" """
config = get_app_config() if app_config.checkpointer is None:
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore() yield InMemoryStore()
return return
async with _async_store(config.checkpointer) as store: async with _async_store(app_config.checkpointer) as store:
yield store yield store
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive _store_ctx = None # open context manager keeping the connection alive
def get_store() -> BaseStore: def get_store(app_config: AppConfig) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call. """Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -115,19 +115,10 @@ def get_store() -> BaseStore:
if _store is not None: if _store is not None:
return _store return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so # See matching comment in checkpointer/provider.py: a missing config.yaml
# that tests that set the global checkpointer config explicitly remain isolated. # is a deployment error, not a cue to silently pick InMemoryStore. Only
from deerflow.config.app_config import _app_config # the explicit "no checkpointer section" path falls through to memory.
from deerflow.config.checkpointer_config import get_checkpointer_config config = app_config.checkpointer
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
@@ -163,26 +154,25 @@ def reset_store() -> None:
@contextlib.contextmanager @contextlib.contextmanager
def store_context() -> Iterator[BaseStore]: def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit. """Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance each Unlike :func:`get_store`, this does **not** cache the instance each
``with`` block creates and destroys its own connection. Use it in CLI ``with`` block creates and destroys its own connection. Use it in CLI
scripts or tests where you want deterministic cleanup:: scripts or tests where you want deterministic cleanup::
with store_context() as store: with store_context(app_config) as store:
store.put(("threads",), thread_id, {...}) store.put(("threads",), thread_id, {...})
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*. checkpointer is configured in *config.yaml*.
""" """
config = get_app_config() if app_config.checkpointer is None:
if config.checkpointer is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore() yield InMemoryStore()
return return
with _sync_store_cm(config.checkpointer) as store: with _sync_store_cm(app_config.checkpointer) as store:
yield store yield store
@@ -17,7 +17,7 @@ import contextlib
import logging import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from deerflow.config.stream_bridge_config import get_stream_bridge_config from deerflow.config.app_config import AppConfig
from .base import StreamBridge from .base import StreamBridge
@@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]: async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
"""Async context manager that yields a :class:`StreamBridge`. """Async context manager that yields a :class:`StreamBridge`.
Falls back to :class:`MemoryStreamBridge` when no configuration is Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
provided and nothing is set globally. section is configured.
""" """
if config is None: config = app_config.stream_bridge
config = get_stream_bridge_config()
if config is None or config.type == "memory": if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -1,10 +1,14 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider from deerflow.sandbox.sandbox_provider import SandboxProvider
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_singleton: LocalSandbox | None = None _singleton: LocalSandbox | None = None
@@ -13,8 +17,9 @@ _singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider): class LocalSandboxProvider(SandboxProvider):
uses_thread_data_mounts = True uses_thread_data_mounts = True
def __init__(self): def __init__(self, app_config: "AppConfig"):
"""Initialize the local sandbox provider with path mappings.""" """Initialize the local sandbox provider with path mappings."""
self._app_config = app_config
self._path_mappings = self._setup_path_mappings() self._path_mappings = self._setup_path_mappings()
def _setup_path_mappings(self) -> list[PathMapping]: def _setup_path_mappings(self) -> list[PathMapping]:
@@ -31,9 +36,7 @@ class LocalSandboxProvider(SandboxProvider):
# Map skills container path to local skills directory # Map skills container path to local skills directory
try: try:
from deerflow.config import get_app_config config = self._app_config
config = get_app_config()
skills_path = config.skills.get_skills_path() skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path container_path = config.skills.container_path
@@ -6,6 +6,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.sandbox import get_sandbox_provider from deerflow.sandbox import get_sandbox_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,41 +43,35 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
super().__init__() super().__init__()
self._lazy_init = lazy_init self._lazy_init = lazy_init
def _acquire_sandbox(self, thread_id: str) -> str: def _acquire_sandbox(self, thread_id: str, runtime: Runtime[DeerFlowContext]) -> str:
provider = get_sandbox_provider() provider = get_sandbox_provider(runtime.context.app_config)
sandbox_id = provider.acquire(thread_id) sandbox_id = provider.acquire(thread_id)
logger.info(f"Acquiring sandbox {sandbox_id}") logger.info(f"Acquiring sandbox {sandbox_id}")
return sandbox_id return sandbox_id
@override @override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
# Skip acquisition if lazy_init is enabled # Skip acquisition if lazy_init is enabled
if self._lazy_init: if self._lazy_init:
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
# Eager initialization (original behavior) # Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None: if "sandbox" not in state or state["sandbox"] is None:
thread_id = (runtime.context or {}).get("thread_id") thread_id = runtime.context.thread_id
if thread_id is None: if not thread_id:
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id) sandbox_id = self._acquire_sandbox(thread_id, runtime)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}") logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
return {"sandbox": {"sandbox_id": sandbox_id}} return {"sandbox": {"sandbox_id": sandbox_id}}
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
@override @override
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
sandbox = state.get("sandbox") sandbox = state.get("sandbox")
if sandbox is not None: if sandbox is not None:
sandbox_id = sandbox["sandbox_id"] sandbox_id = sandbox["sandbox_id"]
logger.info(f"Releasing sandbox {sandbox_id}") logger.info(f"Releasing sandbox {sandbox_id}")
get_sandbox_provider().release(sandbox_id) get_sandbox_provider(runtime.context.app_config).release(sandbox_id)
return None
if (runtime.context or {}).get("sandbox_id") is not None:
sandbox_id = runtime.context.get("sandbox_id")
logger.info(f"Releasing sandbox {sandbox_id} from context")
get_sandbox_provider().release(sandbox_id)
return None return None
# No sandbox to release # No sandbox to release
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class from deerflow.reflection import resolve_class
from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox import Sandbox
@@ -41,23 +41,38 @@ class SandboxProvider(ABC):
_default_sandbox_provider: SandboxProvider | None = None _default_sandbox_provider: SandboxProvider | None = None
def get_sandbox_provider(**kwargs) -> SandboxProvider: def get_sandbox_provider(app_config: AppConfig, **kwargs) -> SandboxProvider:
"""Get the sandbox provider singleton. """Get the sandbox provider singleton.
Returns a cached singleton instance. Use `reset_sandbox_provider()` to clear Returns a cached singleton instance. Use `reset_sandbox_provider()` to clear
the cache, or `shutdown_sandbox_provider()` to properly shutdown and clear. the cache, or `shutdown_sandbox_provider()` to properly shutdown and clear.
Args:
app_config: Application config used the first time the singleton is built.
Ignored on subsequent calls the cached instance is returned
regardless of the config passed.
Returns: Returns:
A sandbox provider instance. A sandbox provider instance.
""" """
global _default_sandbox_provider global _default_sandbox_provider
if _default_sandbox_provider is None: if _default_sandbox_provider is None:
config = get_app_config() cls = resolve_class(app_config.sandbox.use, SandboxProvider)
cls = resolve_class(config.sandbox.use, SandboxProvider) _default_sandbox_provider = cls(app_config=app_config, **kwargs) if _accepts_app_config(cls) else cls(**kwargs)
_default_sandbox_provider = cls(**kwargs)
return _default_sandbox_provider return _default_sandbox_provider
def _accepts_app_config(cls: type) -> bool:
"""Return True when the provider's __init__ accepts an ``app_config`` kwarg."""
import inspect
try:
sig = inspect.signature(cls.__init__)
except (TypeError, ValueError):
return False
return "app_config" in sig.parameters
def reset_sandbox_provider() -> None: def reset_sandbox_provider() -> None:
"""Reset the sandbox provider singleton. """Reset the sandbox provider singleton.
@@ -1,6 +1,6 @@
"""Security helpers for sandbox capability gating.""" """Security helpers for sandbox capability gating."""
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
_LOCAL_SANDBOX_PROVIDER_MARKERS = ( _LOCAL_SANDBOX_PROVIDER_MARKERS = (
"deerflow.sandbox.local:LocalSandboxProvider", "deerflow.sandbox.local:LocalSandboxProvider",
@@ -20,11 +20,8 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
) )
def uses_local_sandbox_provider(config=None) -> bool: def uses_local_sandbox_provider(config: AppConfig) -> bool:
"""Return True when the active sandbox provider is the host-local provider.""" """Return True when the active sandbox provider is the host-local provider."""
if config is None:
config = get_app_config()
sandbox_cfg = getattr(config, "sandbox", None) sandbox_cfg = getattr(config, "sandbox", None)
sandbox_use = getattr(sandbox_cfg, "use", "") sandbox_use = getattr(sandbox_cfg, "use", "")
if sandbox_use in _LOCAL_SANDBOX_PROVIDER_MARKERS: if sandbox_use in _LOCAL_SANDBOX_PROVIDER_MARKERS:
@@ -32,11 +29,8 @@ def uses_local_sandbox_provider(config=None) -> bool:
return sandbox_use.endswith(":LocalSandboxProvider") and "deerflow.sandbox.local" in sandbox_use return sandbox_use.endswith(":LocalSandboxProvider") and "deerflow.sandbox.local" in sandbox_use
def is_host_bash_allowed(config=None) -> bool: def is_host_bash_allowed(config: AppConfig) -> bool:
"""Return whether host bash execution is explicitly allowed.""" """Return whether host bash execution is explicitly allowed."""
if config is None:
config = get_app_config()
sandbox_cfg = getattr(config, "sandbox", None) sandbox_cfg = getattr(config, "sandbox", None)
if sandbox_cfg is None: if sandbox_cfg is None:
return False return False
+159 -197
View File
@@ -7,7 +7,8 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
from deerflow.config.paths import VIRTUAL_PATH_PREFIX from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import ( from deerflow.sandbox.exceptions import (
SandboxError, SandboxError,
@@ -39,62 +40,43 @@ _DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500 _MAX_GREP_MAX_RESULTS = 500
def _get_skills_container_path() -> str: def _get_skills_container_path(app_config: AppConfig) -> str:
"""Get the skills container path from config, with fallback to default. """Get the skills container path from config, with fallback to default."""
skills_cfg = getattr(app_config, "skills", None)
Result is cached after the first successful config load. If config loading if skills_cfg is None:
fails the default is returned *without* caching so that a later call can
pick up the real value once the config is available.
"""
cached = getattr(_get_skills_container_path, "_cached", None)
if cached is not None:
return cached
try:
from deerflow.config import get_app_config
value = get_app_config().skills.container_path
_get_skills_container_path._cached = value # type: ignore[attr-defined]
return value
except Exception:
return _DEFAULT_SKILLS_CONTAINER_PATH return _DEFAULT_SKILLS_CONTAINER_PATH
return skills_cfg.container_path
def _get_skills_host_path() -> str | None: def _get_skills_host_path(app_config: AppConfig) -> str | None:
"""Get the skills host filesystem path from config. """Get the skills host filesystem path from config.
Returns None if the skills directory does not exist or config cannot be Returns None if the skills directory does not exist or is not configured.
loaded. Only successful lookups are cached; failures are retried on the
next call so that a transiently unavailable skills directory does not
permanently disable skills access.
""" """
cached = getattr(_get_skills_host_path, "_cached", None) skills_cfg = getattr(app_config, "skills", None)
if cached is not None: if skills_cfg is None:
return cached return None
try: try:
from deerflow.config import get_app_config skills_path = skills_cfg.get_skills_path()
config = get_app_config()
skills_path = config.skills.get_skills_path()
if skills_path.exists():
value = str(skills_path)
_get_skills_host_path._cached = value # type: ignore[attr-defined]
return value
except Exception: except Exception:
pass return None
if skills_path.exists():
return str(skills_path)
return None return None
def _is_skills_path(path: str) -> bool: def _is_skills_path(path: str, app_config: AppConfig) -> bool:
"""Check if a path is under the skills container path.""" """Check if a path is under the skills container path."""
skills_prefix = _get_skills_container_path() skills_prefix = _get_skills_container_path(app_config)
return path == skills_prefix or path.startswith(f"{skills_prefix}/") return path == skills_prefix or path.startswith(f"{skills_prefix}/")
def _resolve_skills_path(path: str) -> str: def _resolve_skills_path(path: str, app_config: AppConfig) -> str:
"""Resolve a virtual skills path to a host filesystem path. """Resolve a virtual skills path to a host filesystem path.
Args: Args:
path: Virtual skills path (e.g. /mnt/skills/public/bootstrap/SKILL.md) path: Virtual skills path (e.g. /mnt/skills/public/bootstrap/SKILL.md)
app_config: Resolved application config.
Returns: Returns:
Resolved host path. Resolved host path.
@@ -102,8 +84,8 @@ def _resolve_skills_path(path: str) -> str:
Raises: Raises:
FileNotFoundError: If skills directory is not configured or doesn't exist. FileNotFoundError: If skills directory is not configured or doesn't exist.
""" """
skills_container = _get_skills_container_path() skills_container = _get_skills_container_path(app_config)
skills_host = _get_skills_host_path() skills_host = _get_skills_host_path(app_config)
if skills_host is None: if skills_host is None:
raise FileNotFoundError(f"Skills directory not available for path: {path}") raise FileNotFoundError(f"Skills directory not available for path: {path}")
@@ -119,48 +101,31 @@ def _is_acp_workspace_path(path: str) -> bool:
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/") return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
def _get_custom_mounts(): def _get_custom_mounts(app_config: AppConfig):
"""Get custom volume mounts from sandbox config. """Get custom volume mounts from sandbox config.
Result is cached after the first successful config load. If config loading Only includes mounts whose host_path exists, consistent with
fails an empty list is returned *without* caching so that a later call can ``LocalSandboxProvider._setup_path_mappings()`` which also filters by
pick up the real value once the config is available. ``host_path.exists()``.
""" """
cached = getattr(_get_custom_mounts, "_cached", None) sandbox_cfg = getattr(app_config, "sandbox", None)
if cached is not None: if sandbox_cfg is None or not sandbox_cfg.mounts:
return cached
try:
from pathlib import Path
from deerflow.config import get_app_config
config = get_app_config()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
# LocalSandboxProvider._setup_path_mappings() which also filters
# by host_path.exists().
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
return mounts
except Exception:
# If config loading fails, return an empty list without caching so that
# a later call can retry once the config is available.
return [] return []
return [m for m in sandbox_cfg.mounts if Path(m.host_path).exists()]
def _is_custom_mount_path(path: str) -> bool: def _is_custom_mount_path(path: str, app_config: AppConfig) -> bool:
"""Check if path is under a custom mount container_path.""" """Check if path is under a custom mount container_path."""
for mount in _get_custom_mounts(): for mount in _get_custom_mounts(app_config):
if path == mount.container_path or path.startswith(f"{mount.container_path}/"): if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
return True return True
return False return False
def _get_custom_mount_for_path(path: str): def _get_custom_mount_for_path(path: str, app_config: AppConfig):
"""Get the mount config matching this path (longest prefix first).""" """Get the mount config matching this path (longest prefix first)."""
best = None best = None
for mount in _get_custom_mounts(): for mount in _get_custom_mounts(app_config):
if path == mount.container_path or path.startswith(f"{mount.container_path}/"): if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
if best is None or len(mount.container_path) > len(best.container_path): if best is None or len(mount.container_path) > len(best.container_path):
best = mount best = mount
@@ -271,44 +236,40 @@ def _resolve_acp_workspace_path(path: str, thread_id: str | None = None) -> str:
return str(resolved_path) return str(resolved_path)
def _get_mcp_allowed_paths() -> list[str]: def _get_mcp_allowed_paths(app_config: AppConfig) -> list[str]:
"""Get the list of allowed paths from MCP config for file system server.""" """Get the list of allowed paths from MCP config for file system server."""
allowed_paths = [] allowed_paths: list[str] = []
try: extensions_config = getattr(app_config, "extensions", None)
from deerflow.config.extensions_config import get_extensions_config if extensions_config is None:
return allowed_paths
extensions_config = get_extensions_config() for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
continue
for _, server in extensions_config.mcp_servers.items(): # Only check the filesystem server
if not server.enabled: args = server.args or []
continue # Check if args has server-filesystem package
has_filesystem = any("server-filesystem" in arg for arg in args)
# Only check the filesystem server if not has_filesystem:
args = server.args or [] continue
# Check if args has server-filesystem package # Unpack the allowed file system paths in config
has_filesystem = any("server-filesystem" in arg for arg in args) for arg in args:
if not has_filesystem: if not arg.startswith("-") and arg.startswith("/"):
continue allowed_paths.append(arg.rstrip("/") + "/")
# Unpack the allowed file system paths in config
for arg in args:
if not arg.startswith("-") and arg.startswith("/"):
allowed_paths.append(arg.rstrip("/") + "/")
except Exception:
pass
return allowed_paths return allowed_paths
def _get_tool_config_int(name: str, key: str, default: int) -> int: def _get_tool_config_int(app_config: AppConfig, name: str, key: str, default: int) -> int:
try: try:
tool_config = get_app_config().get_tool_config(name) tool_config = app_config.get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
except Exception: except Exception:
pass return default
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
return default return default
@@ -318,23 +279,23 @@ def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
return min(value, upper_bound) return min(value, upper_bound)
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int: def _resolve_max_results(app_config: AppConfig, name: str, requested: int, *, default: int, upper_bound: int) -> int:
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound) requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
configured_max_results = _clamp_max_results( configured_max_results = _clamp_max_results(
_get_tool_config_int(name, "max_results", default), _get_tool_config_int(app_config, name, "max_results", default),
default=default, default=default,
upper_bound=upper_bound, upper_bound=upper_bound,
) )
return min(requested_max_results, configured_max_results) return min(requested_max_results, configured_max_results)
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str: def _resolve_local_read_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
validate_local_tool_path(path, thread_data, read_only=True) validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path): if _is_skills_path(path, app_config):
return _resolve_skills_path(path) return _resolve_skills_path(path, app_config)
if _is_acp_workspace_path(path): if _is_acp_workspace_path(path):
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data)) return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
return _resolve_and_validate_user_data_path(path, thread_data) return _resolve_and_validate_user_data_path(path, thread_data, app_config)
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str: def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
@@ -380,7 +341,11 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
return f"{stripped_base}{separator}{normalized_relative}" return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str: def _sanitize_error(
error: Exception,
runtime: "ToolRuntime[ContextT, ThreadState] | None" = None,
app_config: AppConfig | None = None,
) -> str:
"""Sanitize an error message to avoid leaking host filesystem paths. """Sanitize an error message to avoid leaking host filesystem paths.
In local-sandbox mode, resolved host paths in the error string are masked In local-sandbox mode, resolved host paths in the error string are masked
@@ -389,8 +354,12 @@ def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadStat
""" """
msg = f"{type(error).__name__}: {error}" msg = f"{type(error).__name__}: {error}"
if runtime is not None and is_local_sandbox(runtime): if runtime is not None and is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) if app_config is None:
msg = mask_local_paths_in_output(msg, thread_data) ctx = getattr(runtime, "context", None)
app_config = getattr(ctx, "app_config", None)
if app_config is not None:
thread_data = get_thread_data(runtime)
msg = mask_local_paths_in_output(msg, thread_data, app_config)
return msg return msg
@@ -460,7 +429,7 @@ def _thread_actual_to_virtual_mappings(thread_data: ThreadDataState) -> dict[str
return {actual: virtual for virtual, actual in _thread_virtual_to_actual_mappings(thread_data).items()} return {actual: virtual for virtual, actual in _thread_virtual_to_actual_mappings(thread_data).items()}
def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None) -> str: def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
"""Mask host absolute paths from local sandbox output using virtual paths. """Mask host absolute paths from local sandbox output using virtual paths.
Handles user-data paths (per-thread), skills paths, and ACP workspace paths (global). Handles user-data paths (per-thread), skills paths, and ACP workspace paths (global).
@@ -468,8 +437,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
result = output result = output
# Mask skills host paths # Mask skills host paths
skills_host = _get_skills_host_path() skills_host = _get_skills_host_path(app_config)
skills_container = _get_skills_container_path() skills_container = _get_skills_container_path(app_config)
if skills_host: if skills_host:
raw_base = str(Path(skills_host)) raw_base = str(Path(skills_host))
resolved_base = str(Path(skills_host).resolve()) resolved_base = str(Path(skills_host).resolve())
@@ -543,7 +512,13 @@ def _reject_path_traversal(path: str) -> None:
raise PermissionError("Access denied: path traversal detected") raise PermissionError("Access denied: path traversal detected")
def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *, read_only: bool = False) -> None: def validate_local_tool_path(
path: str,
thread_data: ThreadDataState | None,
app_config: AppConfig,
*,
read_only: bool = False,
) -> None:
"""Validate that a virtual path is allowed for local-sandbox access. """Validate that a virtual path is allowed for local-sandbox access.
This function is a security gate it checks whether *path* may be This function is a security gate it checks whether *path* may be
@@ -572,7 +547,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
_reject_path_traversal(path) _reject_path_traversal(path)
# Skills paths — read-only access only # Skills paths — read-only access only
if _is_skills_path(path): if _is_skills_path(path, app_config):
if not read_only: if not read_only:
raise PermissionError(f"Write access to skills path is not allowed: {path}") raise PermissionError(f"Write access to skills path is not allowed: {path}")
return return
@@ -588,13 +563,13 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
return return
# Custom mount paths — respect read_only config # Custom mount paths — respect read_only config
if _is_custom_mount_path(path): if _is_custom_mount_path(path, app_config):
mount = _get_custom_mount_for_path(path) mount = _get_custom_mount_for_path(path, app_config)
if mount and mount.read_only and not read_only: if mount and mount.read_only and not read_only:
raise PermissionError(f"Write access to read-only mount is not allowed: {path}") raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
return return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed") raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path(app_config)}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None: def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
@@ -625,18 +600,23 @@ def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataSta
raise PermissionError("Access denied: path traversal detected") raise PermissionError("Access denied: path traversal detected")
def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState) -> str: def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
"""Resolve a /mnt/user-data virtual path and validate it stays in bounds. """Resolve a /mnt/user-data virtual path and validate it stays in bounds.
Returns the resolved host path string. Returns the resolved host path string.
``app_config`` is accepted for signature symmetry with the other resolver
helpers; the user-data resolution path itself is fully derivable from
``thread_data``.
""" """
_ = app_config # noqa: F841 — kept for interface symmetry with sibling resolvers
resolved_str = replace_virtual_path(path, thread_data) resolved_str = replace_virtual_path(path, thread_data)
resolved = Path(resolved_str).resolve() resolved = Path(resolved_str).resolve()
_validate_resolved_user_data_path(resolved, thread_data) _validate_resolved_user_data_path(resolved, thread_data)
return str(resolved) return str(resolved)
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None: def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> None:
"""Validate absolute paths in local-sandbox bash commands. """Validate absolute paths in local-sandbox bash commands.
This validation is only a best-effort guard for the explicit This validation is only a best-effort guard for the explicit
@@ -660,7 +640,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}") raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")
unsafe_paths: list[str] = [] unsafe_paths: list[str] = []
allowed_paths = _get_mcp_allowed_paths() allowed_paths = _get_mcp_allowed_paths(app_config)
for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command): for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command):
# Check for MCP filesystem server allowed paths # Check for MCP filesystem server allowed paths
@@ -673,7 +653,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
continue continue
# Allow skills container path (resolved by tools.py before passing to sandbox) # Allow skills container path (resolved by tools.py before passing to sandbox)
if _is_skills_path(absolute_path): if _is_skills_path(absolute_path, app_config):
_reject_path_traversal(absolute_path) _reject_path_traversal(absolute_path)
continue continue
@@ -683,7 +663,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
continue continue
# Allow custom mount container paths # Allow custom mount container paths
if _is_custom_mount_path(absolute_path): if _is_custom_mount_path(absolute_path, app_config):
_reject_path_traversal(absolute_path) _reject_path_traversal(absolute_path)
continue continue
@@ -697,12 +677,13 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
raise PermissionError(f"Unsafe absolute paths in command: {unsafe}. Use paths under {VIRTUAL_PATH_PREFIX}") raise PermissionError(f"Unsafe absolute paths in command: {unsafe}. Use paths under {VIRTUAL_PATH_PREFIX}")
def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None) -> str: def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
"""Replace all virtual paths (/mnt/user-data, /mnt/skills, /mnt/acp-workspace) in a command string. """Replace all virtual paths (/mnt/user-data, /mnt/skills, /mnt/acp-workspace) in a command string.
Args: Args:
command: The command string that may contain virtual paths. command: The command string that may contain virtual paths.
thread_data: The thread data containing actual paths. thread_data: The thread data containing actual paths.
app_config: Resolved application config.
Returns: Returns:
The command with all virtual paths replaced. The command with all virtual paths replaced.
@@ -710,13 +691,13 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
result = command result = command
# Replace skills paths # Replace skills paths
skills_container = _get_skills_container_path() skills_container = _get_skills_container_path(app_config)
skills_host = _get_skills_host_path() skills_host = _get_skills_host_path(app_config)
if skills_host and skills_container in result: if skills_host and skills_container in result:
skills_pattern = re.compile(rf"{re.escape(skills_container)}(/[^\s\"';&|<>()]*)?") skills_pattern = re.compile(rf"{re.escape(skills_container)}(/[^\s\"';&|<>()]*)?")
def replace_skills_match(match: re.Match) -> str: def replace_skills_match(match: re.Match) -> str:
return _resolve_skills_path(match.group(0)) return _resolve_skills_path(match.group(0), app_config)
result = skills_pattern.sub(replace_skills_match, result) result = skills_pattern.sub(replace_skills_match, result)
@@ -806,12 +787,10 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
sandbox_id = sandbox_state.get("sandbox_id") sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is None: if sandbox_id is None:
raise SandboxRuntimeError("Sandbox ID not found in state") raise SandboxRuntimeError("Sandbox ID not found in state")
sandbox = get_sandbox_provider().get(sandbox_id) sandbox = get_sandbox_provider(resolve_context(runtime).app_config).get(sandbox_id)
if sandbox is None: if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id) raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox return sandbox
@@ -839,26 +818,24 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if runtime.state is None: if runtime.state is None:
raise SandboxRuntimeError("Tool runtime state not available") raise SandboxRuntimeError("Tool runtime state not available")
app_config = runtime.context.app_config
# Check if sandbox already exists in state # Check if sandbox already exists in state
sandbox_state = runtime.state.get("sandbox") sandbox_state = runtime.state.get("sandbox")
if sandbox_state is not None: if sandbox_state is not None:
sandbox_id = sandbox_state.get("sandbox_id") sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is not None: if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id) sandbox = get_sandbox_provider(app_config).get(sandbox_id)
if sandbox is not None: if sandbox is not None:
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox return sandbox
# Sandbox was released, fall through to acquire new one # Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox # Lazy acquisition: get thread_id and acquire sandbox
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = runtime.context.thread_id
if thread_id is None: if not thread_id:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
if thread_id is None:
raise SandboxRuntimeError("Thread ID not available in runtime context") raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider() provider = get_sandbox_provider(app_config)
sandbox_id = provider.acquire(thread_id) sandbox_id = provider.acquire(thread_id)
# Update runtime state - this persists across tool calls # Update runtime state - this persists across tool calls
@@ -869,8 +846,6 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None: if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id) raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox return sandbox
@@ -1000,40 +975,29 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
command: The bash command to execute. Always use absolute paths for files and directories. command: The bash command to execute. Always use absolute paths for files and directories.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
if not is_host_bash_allowed(): if not is_host_bash_allowed(app_config):
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}" return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_bash_command_paths(command, thread_data) validate_local_bash_command_paths(command, thread_data, app_config)
command = replace_virtual_paths_in_command(command, thread_data) command = replace_virtual_paths_in_command(command, thread_data, app_config)
command = _apply_cwd_prefix(command, thread_data) command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command) output = sandbox.execute_command(command)
try: return _truncate_bash_output(mask_local_paths_in_output(output, thread_data, app_config), max_chars)
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(sandbox.execute_command(command), max_chars) return _truncate_bash_output(sandbox.execute_command(command), max_chars)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"
except PermissionError as e: except PermissionError as e:
return f"Error: {e}" return f"Error: {e}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime, app_config)}"
@tool("ls", parse_docstring=True) @tool("ls", parse_docstring=True)
@@ -1044,6 +1008,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the directory to list. path: The **absolute** path to the directory to list.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
@@ -1051,13 +1016,13 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
thread_data = None thread_data = None
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True) validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path): if _is_skills_path(path, app_config):
path = _resolve_skills_path(path) path = _resolve_skills_path(path, app_config)
elif _is_acp_workspace_path(path): elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data)) path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
elif not _is_custom_mount_path(path): elif not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data) path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path() # Custom mount paths are resolved by LocalSandbox._resolve_path()
children = sandbox.list_dir(path) children = sandbox.list_dir(path)
if not children: if not children:
@@ -1065,13 +1030,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
output = "\n".join(children) output = "\n".join(children)
if thread_data is not None: if thread_data is not None:
output = mask_local_paths_in_output(output, thread_data) output = mask_local_paths_in_output(output, thread_data)
try: sandbox_cfg = app_config.sandbox
from deerflow.config.app_config import get_app_config max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_ls_output(output, max_chars) return _truncate_ls_output(output, max_chars)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"
@@ -1080,7 +1040,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
except PermissionError: except PermissionError:
return f"Error: Permission denied: {requested_path}" return f"Error: Permission denied: {requested_path}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime, app_config)}"
@tool("glob", parse_docstring=True) @tool("glob", parse_docstring=True)
@@ -1101,11 +1061,13 @@ def glob_tool(
include_dirs: Whether matching directories should also be returned. Default is False. include_dirs: Whether matching directories should also be returned. Default is False.
max_results: Maximum number of paths to return. Default is 200. max_results: Maximum number of paths to return. Default is 200.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
effective_max_results = _resolve_max_results( effective_max_results = _resolve_max_results(
app_config,
"glob", "glob",
max_results, max_results,
default=_DEFAULT_GLOB_MAX_RESULTS, default=_DEFAULT_GLOB_MAX_RESULTS,
@@ -1116,10 +1078,10 @@ def glob_tool(
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
if thread_data is None: if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox") raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data) path = _resolve_local_read_path(path, thread_data, app_config)
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results) matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
if thread_data is not None: if thread_data is not None:
matches = [mask_local_paths_in_output(match, thread_data) for match in matches] matches = [mask_local_paths_in_output(match, thread_data, app_config) for match in matches]
return _format_glob_results(requested_path, matches, truncated) return _format_glob_results(requested_path, matches, truncated)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"
@@ -1130,7 +1092,7 @@ def glob_tool(
except PermissionError: except PermissionError:
return f"Error: Permission denied: {requested_path}" return f"Error: Permission denied: {requested_path}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime, app_config)}"
@tool("grep", parse_docstring=True) @tool("grep", parse_docstring=True)
@@ -1155,11 +1117,13 @@ def grep_tool(
case_sensitive: Whether matching is case-sensitive. Default is False. case_sensitive: Whether matching is case-sensitive. Default is False.
max_results: Maximum number of matching lines to return. Default is 100. max_results: Maximum number of matching lines to return. Default is 100.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
effective_max_results = _resolve_max_results( effective_max_results = _resolve_max_results(
app_config,
"grep", "grep",
max_results, max_results,
default=_DEFAULT_GREP_MAX_RESULTS, default=_DEFAULT_GREP_MAX_RESULTS,
@@ -1170,7 +1134,7 @@ def grep_tool(
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
if thread_data is None: if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox") raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data) path = _resolve_local_read_path(path, thread_data, app_config)
matches, truncated = sandbox.grep( matches, truncated = sandbox.grep(
path, path,
pattern, pattern,
@@ -1182,7 +1146,7 @@ def grep_tool(
if thread_data is not None: if thread_data is not None:
matches = [ matches = [
GrepMatch( GrepMatch(
path=mask_local_paths_in_output(match.path, thread_data), path=mask_local_paths_in_output(match.path, thread_data, app_config),
line_number=match.line_number, line_number=match.line_number,
line=match.line, line=match.line,
) )
@@ -1200,7 +1164,7 @@ def grep_tool(
except PermissionError: except PermissionError:
return f"Error: Permission denied: {requested_path}" return f"Error: Permission denied: {requested_path}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime, app_config)}"
@tool("read_file", parse_docstring=True) @tool("read_file", parse_docstring=True)
@@ -1219,32 +1183,28 @@ def read_file_tool(
start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range. start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range. end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True) validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path): if _is_skills_path(path, app_config):
path = _resolve_skills_path(path) path = _resolve_skills_path(path, app_config)
elif _is_acp_workspace_path(path): elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data)) path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
elif not _is_custom_mount_path(path): elif not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data) path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path() # Custom mount paths are resolved by LocalSandbox._resolve_path()
content = sandbox.read_file(path) content = sandbox.read_file(path)
if not content: if not content:
return "(empty)" return "(empty)"
if start_line is not None and end_line is not None: if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line]) content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try: sandbox_cfg = app_config.sandbox
from deerflow.config.app_config import get_app_config max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000
return _truncate_read_file_output(content, max_chars) return _truncate_read_file_output(content, max_chars)
except SandboxError as e: except SandboxError as e:
return f"Error: {e}" return f"Error: {e}"
@@ -1255,7 +1215,7 @@ def read_file_tool(
except IsADirectoryError: except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}" return f"Error: Path is a directory, not a file: {requested_path}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime, app_config)}"
@tool("write_file", parse_docstring=True) @tool("write_file", parse_docstring=True)
@@ -1273,15 +1233,16 @@ def write_file_tool(
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND. path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD. content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data) validate_local_tool_path(path, thread_data, app_config)
if not _is_custom_mount_path(path): if not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data) path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path() # Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path): with get_file_operation_lock(sandbox, path):
sandbox.write_file(path, content, append) sandbox.write_file(path, content, append)
@@ -1293,9 +1254,9 @@ def write_file_tool(
except IsADirectoryError: except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}" return f"Error: Path is a directory, not a file: {requested_path}"
except OSError as e: except OSError as e:
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}" return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime, app_config)}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime, app_config)}"
@tool("str_replace", parse_docstring=True) @tool("str_replace", parse_docstring=True)
@@ -1317,15 +1278,16 @@ def str_replace_tool(
new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH. new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False. replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
""" """
app_config = resolve_context(runtime).app_config
try: try:
sandbox = ensure_sandbox_initialized(runtime) sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime) ensure_thread_directories_exist(runtime)
requested_path = path requested_path = path
if is_local_sandbox(runtime): if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime) thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data) validate_local_tool_path(path, thread_data, app_config)
if not _is_custom_mount_path(path): if not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data) path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path() # Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path): with get_file_operation_lock(sandbox, path):
content = sandbox.read_file(path) content = sandbox.read_file(path)
@@ -1346,4 +1308,4 @@ def str_replace_tool(
except PermissionError: except PermissionError:
return f"Error: Permission denied accessing file: {requested_path}" return f"Error: Permission denied accessing file: {requested_path}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime, app_config)}"
@@ -1,10 +1,14 @@
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
from .parser import parse_skill_file from .parser import parse_skill_file
from .types import Skill from .types import Skill
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,7 +26,12 @@ def get_skills_root_path() -> Path:
return skills_dir return skills_dir
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]: def load_skills(
app_config: "AppConfig | None" = None,
*,
skills_path: Path | None = None,
enabled_only: bool = False,
) -> list[Skill]:
""" """
Load all skills from the skills directory. Load all skills from the skills directory.
@@ -30,25 +39,19 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
to extract metadata. The enabled state is determined by the skills_state_config.json file. to extract metadata. The enabled state is determined by the skills_state_config.json file.
Args: Args:
skills_path: Optional custom path to skills directory. app_config: Application config used to resolve the configured skills
If not provided and use_config is True, uses path from config. directory. Ignored when ``skills_path`` is supplied.
Otherwise defaults to deer-flow/skills skills_path: Explicit override for the skills directory. When both
use_config: Whether to load skills path from config (default: True) ``skills_path`` and ``app_config`` are omitted the
default repository layout is used (``deer-flow/skills``).
enabled_only: If True, only return enabled skills (default: False) enabled_only: If True, only return enabled skills (default: False)
Returns: Returns:
List of Skill objects, sorted by name List of Skill objects, sorted by name
""" """
if skills_path is None: if skills_path is None:
if use_config: if app_config is not None:
try: skills_path = app_config.skills.get_skills_path()
from deerflow.config import get_app_config
config = get_app_config()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails
skills_path = get_skills_root_path()
else: else:
skills_path = get_skills_root_path() skills_path = get_skills_root_path()
@@ -9,7 +9,7 @@ from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.skills.loader import load_skills from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter from deerflow.skills.validation import _validate_skill_frontmatter
@@ -20,16 +20,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$") _SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path: def get_skills_root_dir(app_config: AppConfig) -> Path:
return get_app_config().skills.get_skills_path() """Return the configured skills root."""
return app_config.skills.get_skills_path()
def get_public_skills_dir() -> Path: def get_public_skills_dir(app_config: AppConfig) -> Path:
return get_skills_root_dir() / "public" return get_skills_root_dir(app_config) / "public"
def get_custom_skills_dir() -> Path: def get_custom_skills_dir(app_config: AppConfig) -> Path:
path = get_skills_root_dir() / "custom" path = get_skills_root_dir(app_config) / "custom"
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return path return path
@@ -43,46 +44,46 @@ def validate_skill_name(name: str) -> str:
return normalized return normalized
def get_custom_skill_dir(name: str) -> Path: def get_custom_skill_dir(name: str, app_config: AppConfig) -> Path:
return get_custom_skills_dir() / validate_skill_name(name) return get_custom_skills_dir(app_config) / validate_skill_name(name)
def get_custom_skill_file(name: str) -> Path: def get_custom_skill_file(name: str, app_config: AppConfig) -> Path:
return get_custom_skill_dir(name) / SKILL_FILE_NAME return get_custom_skill_dir(name, app_config) / SKILL_FILE_NAME
def get_custom_skill_history_dir() -> Path: def get_custom_skill_history_dir(app_config: AppConfig) -> Path:
path = get_custom_skills_dir() / HISTORY_DIR_NAME path = get_custom_skills_dir(app_config) / HISTORY_DIR_NAME
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
return path return path
def get_skill_history_file(name: str) -> Path: def get_skill_history_file(name: str, app_config: AppConfig) -> Path:
return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl" return get_custom_skill_history_dir(app_config) / f"{validate_skill_name(name)}.jsonl"
def get_public_skill_dir(name: str) -> Path: def get_public_skill_dir(name: str, app_config: AppConfig) -> Path:
return get_public_skills_dir() / validate_skill_name(name) return get_public_skills_dir(app_config) / validate_skill_name(name)
def custom_skill_exists(name: str) -> bool: def custom_skill_exists(name: str, app_config: AppConfig) -> bool:
return get_custom_skill_file(name).exists() return get_custom_skill_file(name, app_config).exists()
def public_skill_exists(name: str) -> bool: def public_skill_exists(name: str, app_config: AppConfig) -> bool:
return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists() return (get_public_skill_dir(name, app_config) / SKILL_FILE_NAME).exists()
def ensure_custom_skill_is_editable(name: str) -> None: def ensure_custom_skill_is_editable(name: str, app_config: AppConfig) -> None:
if custom_skill_exists(name): if custom_skill_exists(name, app_config):
return return
if public_skill_exists(name): if public_skill_exists(name, app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.") raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.") raise FileNotFoundError(f"Custom skill '{name}' not found.")
def ensure_safe_support_path(name: str, relative_path: str) -> Path: def ensure_safe_support_path(name: str, relative_path: str, app_config: AppConfig) -> Path:
skill_dir = get_custom_skill_dir(name).resolve() skill_dir = get_custom_skill_dir(name, app_config).resolve()
if not relative_path or relative_path.endswith("/"): if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.") raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path) relative = Path(relative_path)
@@ -124,8 +125,8 @@ def atomic_write(path: Path, content: str) -> None:
tmp_path.replace(path) tmp_path.replace(path)
def append_history(name: str, record: dict[str, Any]) -> None: def append_history(name: str, record: dict[str, Any], app_config: AppConfig) -> None:
history_path = get_skill_history_file(name) history_path = get_skill_history_file(name, app_config)
history_path.parent.mkdir(parents=True, exist_ok=True) history_path.parent.mkdir(parents=True, exist_ok=True)
payload = { payload = {
"ts": datetime.now(UTC).isoformat(), "ts": datetime.now(UTC).isoformat(),
@@ -136,8 +137,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
f.write("\n") f.write("\n")
def read_history(name: str) -> list[dict[str, Any]]: def read_history(name: str, app_config: AppConfig) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name) history_path = get_skill_history_file(name, app_config)
if not history_path.exists(): if not history_path.exists():
return [] return []
records: list[dict[str, Any]] = [] records: list[dict[str, Any]] = []
@@ -148,12 +149,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
return records return records
def list_custom_skills() -> list: def list_custom_skills(app_config: AppConfig) -> list:
return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"] return [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
def read_custom_skill_content(name: str) -> str: def read_custom_skill_content(name: str, app_config: AppConfig) -> str:
skill_file = get_custom_skill_file(name) skill_file = get_custom_skill_file(name, app_config)
if not skill_file.exists(): if not skill_file.exists():
raise FileNotFoundError(f"Custom skill '{name}' not found.") raise FileNotFoundError(f"Custom skill '{name}' not found.")
return skill_file.read_text(encoding="utf-8") return skill_file.read_text(encoding="utf-8")
@@ -7,7 +7,7 @@ import logging
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,7 +35,7 @@ def _extract_json_object(raw: str) -> dict | None:
return None return None
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult: async def scan_skill_content(app_config: AppConfig, content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
"""Screen skill content before it is written to disk.""" """Screen skill content before it is written to disk."""
rubric = ( rubric = (
"You are a security reviewer for AI agent skills. " "You are a security reviewer for AI agent skills. "
@@ -47,9 +47,12 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----" prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try: try:
config = get_app_config() model_name = app_config.skill_evolution.moderation_model_name
model_name = config.skill_evolution.moderation_model_name model = (
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False) create_chat_model(name=model_name, thinking_enabled=False, app_config=app_config)
if model_name
else create_chat_model(thinking_enabled=False, app_config=app_config)
)
response = await model.ainvoke( response = await model.ainvoke(
[ [
{"role": "system", "content": rubric}, {"role": "system", "content": rubric},
@@ -17,6 +17,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model from deerflow.models import create_chat_model
from deerflow.subagents.config import SubagentConfig from deerflow.subagents.config import SubagentConfig
@@ -132,24 +133,16 @@ class SubagentExecutor:
self, self,
config: SubagentConfig, config: SubagentConfig,
tools: list[BaseTool], tools: list[BaseTool],
app_config: AppConfig,
parent_model: str | None = None, parent_model: str | None = None,
sandbox_state: SandboxState | None = None, sandbox_state: SandboxState | None = None,
thread_data: ThreadDataState | None = None, thread_data: ThreadDataState | None = None,
thread_id: str | None = None, thread_id: str | None = None,
trace_id: str | None = None, trace_id: str | None = None,
): ):
"""Initialize the executor. """Initialize the executor."""
Args:
config: Subagent configuration.
tools: List of all available tools (will be filtered).
parent_model: The parent agent's model name for inheritance.
sandbox_state: Sandbox state from parent agent.
thread_data: Thread data from parent agent.
thread_id: Thread ID for sandbox operations.
trace_id: Trace ID from parent for distributed tracing.
"""
self.config = config self.config = config
self.app_config = app_config
self.parent_model = parent_model self.parent_model = parent_model
self.sandbox_state = sandbox_state self.sandbox_state = sandbox_state
self.thread_data = thread_data self.thread_data = thread_data
@@ -169,7 +162,7 @@ class SubagentExecutor:
def _create_agent(self): def _create_agent(self):
"""Create the agent instance.""" """Create the agent instance."""
model_name = _get_model_name(self.config, self.parent_model) model_name = _get_model_name(self.config, self.parent_model)
model = create_chat_model(name=model_name, thinking_enabled=False) model = create_chat_model(name=model_name, thinking_enabled=False, app_config=self.app_config)
from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares
@@ -3,6 +3,7 @@
import logging import logging
from dataclasses import replace from dataclasses import replace
from deerflow.config.app_config import AppConfig
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.config import SubagentConfig from deerflow.subagents.config import SubagentConfig
@@ -10,19 +11,17 @@ from deerflow.subagents.config import SubagentConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _build_custom_subagent_config(name: str) -> SubagentConfig | None: def _build_custom_subagent_config(name: str, app_config: AppConfig) -> SubagentConfig | None:
"""Build a SubagentConfig from config.yaml custom_agents section. """Build a SubagentConfig from config.yaml custom_agents section.
Args: Args:
name: The name of the custom subagent. name: The name of the custom subagent.
app_config: The resolved application config.
Returns: Returns:
SubagentConfig if found in custom_agents, None otherwise. SubagentConfig if found in custom_agents, None otherwise.
""" """
from deerflow.config.subagents_config import get_subagents_app_config custom = app_config.subagents.custom_agents.get(name)
app_config = get_subagents_app_config()
custom = app_config.custom_agents.get(name)
if custom is None: if custom is None:
return None return None
@@ -39,67 +38,44 @@ def _build_custom_subagent_config(name: str) -> SubagentConfig | None:
) )
def get_subagent_config(name: str) -> SubagentConfig | None: def get_subagent_config(name: str, app_config: AppConfig) -> SubagentConfig | None:
"""Get a subagent configuration by name, with config.yaml overrides applied. """Get a subagent configuration by name, with config.yaml overrides applied.
Resolution order (mirrors Codex's config layering): Resolution order (mirrors Codex's config layering):
1. Built-in subagents (general-purpose, bash) 1. Built-in subagents (general-purpose, bash)
2. Custom subagents from config.yaml custom_agents section 2. Custom subagents from config.yaml custom_agents section
3. Per-agent overrides from config.yaml agents section (timeout, max_turns, model, skills) 3. Per-agent overrides from config.yaml agents section (timeout, max_turns, model, skills)
Args:
name: The name of the subagent.
Returns:
SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
""" """
# Step 1: Look up built-in, then fall back to custom_agents
config = BUILTIN_SUBAGENTS.get(name) config = BUILTIN_SUBAGENTS.get(name)
if config is None: if config is None:
config = _build_custom_subagent_config(name) config = _build_custom_subagent_config(name, app_config)
if config is None: if config is None:
return None return None
# Step 2: Apply per-agent overrides from config.yaml agents section. sub_config = app_config.subagents
# Only explicit per-agent overrides are applied here. Global defaults overrides: dict = {}
# (timeout_seconds, max_turns at the top level) apply to built-in agents
# but must NOT override custom agents' own values — custom agents define
# their own defaults in the custom_agents section.
# Lazy import to avoid circular deps.
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config() # Timeout: subagents config supplies effective per-agent override or global default.
is_builtin = name in BUILTIN_SUBAGENTS effective_timeout = sub_config.get_timeout_for(name)
agent_override = app_config.agents.get(name) if effective_timeout != config.timeout_seconds:
logger.debug("Subagent '%s': timeout overridden (%ss -> %ss)", name, config.timeout_seconds, effective_timeout)
overrides["timeout_seconds"] = effective_timeout
overrides = {} # Max turns: subagents config supplies effective per-agent override or global default
# (falls back to ``config.max_turns`` when no override is configured).
# Timeout: per-agent override > global default (builtins only) > config's own value effective_max_turns = sub_config.get_max_turns_for(name, config.max_turns)
if agent_override is not None and agent_override.timeout_seconds is not None: if effective_max_turns != config.max_turns:
if agent_override.timeout_seconds != config.timeout_seconds: logger.debug("Subagent '%s': max_turns overridden (%s -> %s)", name, config.max_turns, effective_max_turns)
logger.debug("Subagent '%s': timeout overridden (%ss -> %ss)", name, config.timeout_seconds, agent_override.timeout_seconds) overrides["max_turns"] = effective_max_turns
overrides["timeout_seconds"] = agent_override.timeout_seconds
elif is_builtin and app_config.timeout_seconds != config.timeout_seconds:
logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, app_config.timeout_seconds)
overrides["timeout_seconds"] = app_config.timeout_seconds
# Max turns: per-agent override > global default (builtins only) > config's own value
if agent_override is not None and agent_override.max_turns is not None:
if agent_override.max_turns != config.max_turns:
logger.debug("Subagent '%s': max_turns overridden (%s -> %s)", name, config.max_turns, agent_override.max_turns)
overrides["max_turns"] = agent_override.max_turns
elif is_builtin and app_config.max_turns is not None and app_config.max_turns != config.max_turns:
logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, app_config.max_turns)
overrides["max_turns"] = app_config.max_turns
# Model: per-agent override only (no global default for model) # Model: per-agent override only (no global default for model)
effective_model = app_config.get_model_for(name) effective_model = sub_config.get_model_for(name)
if effective_model is not None and effective_model != config.model: if effective_model is not None and effective_model != config.model:
logger.debug("Subagent '%s': model overridden (%s -> %s)", name, config.model, effective_model) logger.debug("Subagent '%s': model overridden (%s -> %s)", name, config.model, effective_model)
overrides["model"] = effective_model overrides["model"] = effective_model
# Skills: per-agent override only (no global default for skills) # Skills: per-agent override only (no global default for skills)
effective_skills = app_config.get_skills_for(name) effective_skills = sub_config.get_skills_for(name)
if effective_skills is not None and effective_skills != config.skills: if effective_skills is not None and effective_skills != config.skills:
logger.debug("Subagent '%s': skills overridden (%s -> %s)", name, config.skills, effective_skills) logger.debug("Subagent '%s': skills overridden (%s -> %s)", name, config.skills, effective_skills)
overrides["skills"] = effective_skills overrides["skills"] = effective_skills
@@ -110,21 +86,21 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
return config return config
def list_subagents() -> list[SubagentConfig]: def list_subagents(app_config: AppConfig) -> list[SubagentConfig]:
"""List all available subagent configurations (with config.yaml overrides applied). """List all available subagent configurations (with config.yaml overrides applied).
Returns: Returns:
List of all registered SubagentConfig instances (built-in + custom). List of all registered SubagentConfig instances (built-in + custom).
""" """
configs = [] configs: list[SubagentConfig] = []
for name in get_subagent_names(): for name in get_subagent_names(app_config):
config = get_subagent_config(name) config = get_subagent_config(name, app_config)
if config is not None: if config is not None:
configs.append(config) configs.append(config)
return configs return configs
def get_subagent_names() -> list[str]: def get_subagent_names(app_config: AppConfig) -> list[str]:
"""Get all available subagent names (built-in + custom). """Get all available subagent names (built-in + custom).
Returns: Returns:
@@ -132,26 +108,22 @@ def get_subagent_names() -> list[str]:
""" """
names = list(BUILTIN_SUBAGENTS.keys()) names = list(BUILTIN_SUBAGENTS.keys())
# Merge custom_agents from config.yaml for custom_name in app_config.subagents.custom_agents:
from deerflow.config.subagents_config import get_subagents_app_config
app_config = get_subagents_app_config()
for custom_name in app_config.custom_agents:
if custom_name not in names: if custom_name not in names:
names.append(custom_name) names.append(custom_name)
return names return names
def get_available_subagent_names() -> list[str]: def get_available_subagent_names(app_config: AppConfig) -> list[str]:
"""Get subagent names that should be exposed to the active runtime. """Get subagent names that should be exposed to the active runtime.
Returns: Returns:
List of subagent names visible to the current sandbox configuration. List of subagent names visible to the current sandbox configuration.
""" """
names = get_subagent_names() names = get_subagent_names(app_config)
try: try:
host_bash_allowed = is_host_bash_allowed() host_bash_allowed = is_host_bash_allowed(app_config)
except Exception: except Exception:
logger.debug("Could not determine host bash availability; exposing all subagents") logger.debug("Could not determine host bash availability; exposing all subagents")
return names return names
@@ -52,7 +52,7 @@ def _normalize_presented_filepath(
if runtime.state is None: if runtime.state is None:
raise ValueError("Thread runtime state is not available") raise ValueError("Thread runtime state is not available")
thread_id = _get_thread_id(runtime) thread_id = runtime.context.thread_id
if not thread_id: if not thread_id:
raise ValueError("Thread ID is not available in runtime context or runtime config") raise ValueError("Thread ID is not available in runtime context or runtime config")
@@ -66,10 +66,7 @@ def _normalize_presented_filepath(
virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/") virtual_prefix = VIRTUAL_PATH_PREFIX.lstrip("/")
if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"): if stripped == virtual_prefix or stripped.startswith(virtual_prefix + "/"):
try: actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
actual_path = get_paths().resolve_virtual_path(thread_id, filepath, user_id=get_effective_user_id())
except TypeError:
actual_path = get_paths().resolve_virtual_path(thread_id, filepath)
else: else:
actual_path = Path(filepath).expanduser().resolve() actual_path = Path(filepath).expanduser().resolve()
@@ -27,7 +27,7 @@ def setup_agent(
skills: Optional list of skill names this agent should use. None means use all enabled skills, empty list means no skills. skills: Optional list of skill names this agent should use. None means use all enabled skills, empty list means no skills.
""" """
agent_name: str | None = runtime.context.get("agent_name") if runtime.context else None agent_name: str | None = runtime.context.agent_name
agent_dir = None agent_dir = None
is_new_dir = False is_new_dir = False
@@ -11,6 +11,7 @@ from langgraph.config import get_stream_writer
from langgraph.typing import ContextT from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.config.deer_flow_context import resolve_context
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
@@ -74,14 +75,15 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max. max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
""" """
available_subagent_names = get_available_subagent_names() ctx = resolve_context(runtime)
available_subagent_names = get_available_subagent_names(ctx.app_config)
# Get subagent configuration # Get subagent configuration
config = get_subagent_config(subagent_type) config = get_subagent_config(subagent_type, ctx.app_config)
if config is None: if config is None:
available = ", ".join(available_subagent_names) available = ", ".join(available_subagent_names)
return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}" return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
if subagent_type == "bash" and not is_host_bash_allowed(): if subagent_type == "bash" and not is_host_bash_allowed(ctx.app_config):
return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}" return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
# Build config overrides # Build config overrides
@@ -105,9 +107,7 @@ async def task_tool(
if runtime is not None: if runtime is not None:
sandbox_state = runtime.state.get("sandbox") sandbox_state = runtime.state.get("sandbox")
thread_data = runtime.state.get("thread_data") thread_data = runtime.state.get("thread_data")
thread_id = runtime.context.get("thread_id") if runtime.context else None thread_id = runtime.context.thread_id
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id")
# Try to get parent model from configurable # Try to get parent model from configurable
metadata = runtime.config.get("metadata", {}) metadata = runtime.config.get("metadata", {})
@@ -131,12 +131,13 @@ async def task_tool(
parent_tool_groups = metadata.get("tool_groups") parent_tool_groups = metadata.get("tool_groups")
# Subagents should not have subagent tools enabled (prevent recursive nesting) # Subagents should not have subagent tools enabled (prevent recursive nesting)
tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False) tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False, app_config=ctx.app_config)
# Create executor # Create executor
executor = SubagentExecutor( executor = SubagentExecutor(
config=config, config=config,
tools=tools, tools=tools,
app_config=ctx.app_config,
parent_model=parent_model, parent_model=parent_model,
sandbox_state=sandbox_state, sandbox_state=sandbox_state,
thread_data=thread_data, thread_data=thread_data,
@@ -5,7 +5,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import shutil import shutil
from typing import Any from typing import TYPE_CHECKING, Any
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from langchain.tools import ToolRuntime, tool from langchain.tools import ToolRuntime, tool
@@ -13,6 +13,9 @@ from langgraph.typing import ContextT
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
from deerflow.mcp.tools import _make_sync_tool_wrapper from deerflow.mcp.tools import _make_sync_tool_wrapper
from deerflow.skills.manager import ( from deerflow.skills.manager import (
append_history, append_history,
@@ -45,9 +48,7 @@ def _get_lock(name: str) -> asyncio.Lock:
def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None: def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None:
if runtime is None: if runtime is None:
return None return None
if runtime.context and runtime.context.get("thread_id"): return runtime.context.thread_id or None
return runtime.context.get("thread_id")
return runtime.config.get("configurable", {}).get("thread_id")
def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]: def _history_record(*, action: str, file_path: str, prev_content: str | None, new_content: str | None, thread_id: str | None, scanner: dict[str, Any]) -> dict[str, Any]:
@@ -62,8 +63,8 @@ def _history_record(*, action: str, file_path: str, prev_content: str | None, ne
} }
async def _scan_or_raise(content: str, *, executable: bool, location: str) -> dict[str, str]: async def _scan_or_raise(app_config: "AppConfig", content: str, *, executable: bool, location: str) -> dict[str, str]:
result = await scan_skill_content(content, executable=executable, location=location) result = await scan_skill_content(app_config, content, executable=executable, location=location)
if result.decision == "block": if result.decision == "block":
raise ValueError(f"Security scan blocked the write: {result.reason}") raise ValueError(f"Security scan blocked the write: {result.reason}")
if executable and result.decision != "allow": if executable and result.decision != "allow":
@@ -96,50 +97,55 @@ async def _skill_manage_impl(
replace: Replacement text for patch. replace: Replacement text for patch.
expected_count: Optional expected number of replacements for patch. expected_count: Optional expected number of replacements for patch.
""" """
from deerflow.config.deer_flow_context import resolve_context
name = validate_skill_name(name) name = validate_skill_name(name)
lock = _get_lock(name) lock = _get_lock(name)
thread_id = _get_thread_id(runtime) thread_id = _get_thread_id(runtime)
app_config = resolve_context(runtime).app_config
async with lock: async with lock:
if action == "create": if action == "create":
if await _to_thread(custom_skill_exists, name): if await _to_thread(custom_skill_exists, name, app_config):
raise ValueError(f"Custom skill '{name}' already exists.") raise ValueError(f"Custom skill '{name}' already exists.")
if content is None: if content is None:
raise ValueError("content is required for create.") raise ValueError("content is required for create.")
await _to_thread(validate_skill_markdown_content, name, content) await _to_thread(validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md") scan = await _scan_or_raise(app_config, content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name) skill_file = await _to_thread(get_custom_skill_file, name, app_config)
await _to_thread(atomic_write, skill_file, content) await _to_thread(atomic_write, skill_file, content)
await _to_thread( await _to_thread(
append_history, append_history,
name, name,
_history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan), _history_record(action="create", file_path="SKILL.md", prev_content=None, new_content=content, thread_id=thread_id, scanner=scan),
app_config,
) )
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return f"Created custom skill '{name}'." return f"Created custom skill '{name}'."
if action == "edit": if action == "edit":
await _to_thread(ensure_custom_skill_is_editable, name) await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if content is None: if content is None:
raise ValueError("content is required for edit.") raise ValueError("content is required for edit.")
await _to_thread(validate_skill_markdown_content, name, content) await _to_thread(validate_skill_markdown_content, name, content)
scan = await _scan_or_raise(content, executable=False, location=f"{name}/SKILL.md") scan = await _scan_or_raise(app_config, content, executable=False, location=f"{name}/SKILL.md")
skill_file = await _to_thread(get_custom_skill_file, name) skill_file = await _to_thread(get_custom_skill_file, name, app_config)
prev_content = await _to_thread(skill_file.read_text, encoding="utf-8") prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
await _to_thread(atomic_write, skill_file, content) await _to_thread(atomic_write, skill_file, content)
await _to_thread( await _to_thread(
append_history, append_history,
name, name,
_history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan), _history_record(action="edit", file_path="SKILL.md", prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
app_config,
) )
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return f"Updated custom skill '{name}'." return f"Updated custom skill '{name}'."
if action == "patch": if action == "patch":
await _to_thread(ensure_custom_skill_is_editable, name) await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if find is None or replace is None: if find is None or replace is None:
raise ValueError("find and replace are required for patch.") raise ValueError("find and replace are required for patch.")
skill_file = await _to_thread(get_custom_skill_file, name) skill_file = await _to_thread(get_custom_skill_file, name, app_config)
prev_content = await _to_thread(skill_file.read_text, encoding="utf-8") prev_content = await _to_thread(skill_file.read_text, encoding="utf-8")
occurrences = prev_content.count(find) occurrences = prev_content.count(find)
if occurrences == 0: if occurrences == 0:
@@ -149,51 +155,54 @@ async def _skill_manage_impl(
replacement_count = expected_count if expected_count is not None else 1 replacement_count = expected_count if expected_count is not None else 1
new_content = prev_content.replace(find, replace, replacement_count) new_content = prev_content.replace(find, replace, replacement_count)
await _to_thread(validate_skill_markdown_content, name, new_content) await _to_thread(validate_skill_markdown_content, name, new_content)
scan = await _scan_or_raise(new_content, executable=False, location=f"{name}/SKILL.md") scan = await _scan_or_raise(app_config, new_content, executable=False, location=f"{name}/SKILL.md")
await _to_thread(atomic_write, skill_file, new_content) await _to_thread(atomic_write, skill_file, new_content)
await _to_thread( await _to_thread(
append_history, append_history,
name, name,
_history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan), _history_record(action="patch", file_path="SKILL.md", prev_content=prev_content, new_content=new_content, thread_id=thread_id, scanner=scan),
app_config,
) )
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)." return f"Patched custom skill '{name}' ({replacement_count} replacement(s) applied, {occurrences} match(es) found)."
if action == "delete": if action == "delete":
await _to_thread(ensure_custom_skill_is_editable, name) await _to_thread(ensure_custom_skill_is_editable, name, app_config)
skill_dir = await _to_thread(get_custom_skill_dir, name) skill_dir = await _to_thread(get_custom_skill_dir, name, app_config)
prev_content = await _to_thread(read_custom_skill_content, name) prev_content = await _to_thread(read_custom_skill_content, name, app_config)
await _to_thread( await _to_thread(
append_history, append_history,
name, name,
_history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}), _history_record(action="delete", file_path="SKILL.md", prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
app_config,
) )
await _to_thread(shutil.rmtree, skill_dir) await _to_thread(shutil.rmtree, skill_dir)
await refresh_skills_system_prompt_cache_async() await refresh_skills_system_prompt_cache_async(app_config)
return f"Deleted custom skill '{name}'." return f"Deleted custom skill '{name}'."
if action == "write_file": if action == "write_file":
await _to_thread(ensure_custom_skill_is_editable, name) await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if path is None or content is None: if path is None or content is None:
raise ValueError("path and content are required for write_file.") raise ValueError("path and content are required for write_file.")
target = await _to_thread(ensure_safe_support_path, name, path) target = await _to_thread(ensure_safe_support_path, name, path, app_config)
exists = await _to_thread(target.exists) exists = await _to_thread(target.exists)
prev_content = await _to_thread(target.read_text, encoding="utf-8") if exists else None prev_content = await _to_thread(target.read_text, encoding="utf-8") if exists else None
executable = "scripts/" in path or path.startswith("scripts/") executable = "scripts/" in path or path.startswith("scripts/")
scan = await _scan_or_raise(content, executable=executable, location=f"{name}/{path}") scan = await _scan_or_raise(app_config, content, executable=executable, location=f"{name}/{path}")
await _to_thread(atomic_write, target, content) await _to_thread(atomic_write, target, content)
await _to_thread( await _to_thread(
append_history, append_history,
name, name,
_history_record(action="write_file", file_path=path, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan), _history_record(action="write_file", file_path=path, prev_content=prev_content, new_content=content, thread_id=thread_id, scanner=scan),
app_config,
) )
return f"Wrote '{path}' for custom skill '{name}'." return f"Wrote '{path}' for custom skill '{name}'."
if action == "remove_file": if action == "remove_file":
await _to_thread(ensure_custom_skill_is_editable, name) await _to_thread(ensure_custom_skill_is_editable, name, app_config)
if path is None: if path is None:
raise ValueError("path is required for remove_file.") raise ValueError("path is required for remove_file.")
target = await _to_thread(ensure_safe_support_path, name, path) target = await _to_thread(ensure_safe_support_path, name, path, app_config)
if not await _to_thread(target.exists): if not await _to_thread(target.exists):
raise FileNotFoundError(f"Supporting file '{path}' not found for skill '{name}'.") raise FileNotFoundError(f"Supporting file '{path}' not found for skill '{name}'.")
prev_content = await _to_thread(target.read_text, encoding="utf-8") prev_content = await _to_thread(target.read_text, encoding="utf-8")
@@ -202,10 +211,11 @@ async def _skill_manage_impl(
append_history, append_history,
name, name,
_history_record(action="remove_file", file_path=path, prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}), _history_record(action="remove_file", file_path=path, prev_content=prev_content, new_content=None, thread_id=thread_id, scanner={"decision": "allow", "reason": "Deletion requested."}),
app_config,
) )
return f"Removed '{path}' from custom skill '{name}'." return f"Removed '{path}' from custom skill '{name}'."
if await _to_thread(public_skill_exists, name): if await _to_thread(public_skill_exists, name, app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.") raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise ValueError(f"Unsupported action '{action}'.") raise ValueError(f"Unsupported action '{action}'.")
@@ -2,7 +2,7 @@ import logging
from langchain.tools import BaseTool from langchain.tools import BaseTool
from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
@@ -37,6 +37,8 @@ def get_available_tools(
include_mcp: bool = True, include_mcp: bool = True,
model_name: str | None = None, model_name: str | None = None,
subagent_enabled: bool = False, subagent_enabled: bool = False,
*,
app_config: AppConfig,
) -> list[BaseTool]: ) -> list[BaseTool]:
"""Get all available tools from config. """Get all available tools from config.
@@ -48,11 +50,12 @@ def get_available_tools(
include_mcp: Whether to include tools from MCP servers (default: True). include_mcp: Whether to include tools from MCP servers (default: True).
model_name: Optional model name to determine if vision tools should be included. model_name: Optional model name to determine if vision tools should be included.
subagent_enabled: Whether to include subagent tools (task, task_status). subagent_enabled: Whether to include subagent tools (task, task_status).
app_config: Application config required.
Returns: Returns:
List of available tools. List of available tools.
""" """
config = get_app_config() config = app_config
tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups] tool_configs = [tool for tool in config.tools if groups is None or tool.group in groups]
# Do not expose host bash by default when LocalSandboxProvider is active. # Do not expose host bash by default when LocalSandboxProvider is active.
@@ -138,10 +141,9 @@ def get_available_tools(
# Add invoke_acp_agent tool if any ACP agents are configured # Add invoke_acp_agent tool if any ACP agents are configured
acp_tools: list[BaseTool] = [] acp_tools: list[BaseTool] = []
try: try:
from deerflow.config.acp_config import get_acp_agents
from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool
acp_agents = get_acp_agents() acp_agents = config.acp_agents
if acp_agents: if acp_agents:
acp_tools.append(build_invoke_acp_agent_tool(acp_agents)) acp_tools.append(build_invoke_acp_agent_tool(acp_agents))
logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})") logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})")
@@ -19,8 +19,6 @@ import logging
import re import re
from pathlib import Path from pathlib import Path
from deerflow.config.app_config import get_app_config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# File extensions that should be converted to markdown # File extensions that should be converted to markdown
@@ -135,7 +133,7 @@ def _do_convert(file_path: Path, pdf_converter: str) -> str:
return _convert_with_markitdown(file_path) return _convert_with_markitdown(file_path)
async def convert_file_to_markdown(file_path: Path) -> Path | None: async def convert_file_to_markdown(file_path: Path, app_config: object | None = None) -> Path | None:
"""Convert a supported document file to Markdown. """Convert a supported document file to Markdown.
PDF files are handled with a two-converter strategy (see module docstring). PDF files are handled with a two-converter strategy (see module docstring).
@@ -144,12 +142,14 @@ async def convert_file_to_markdown(file_path: Path) -> Path | None:
Args: Args:
file_path: Path to the file to convert. file_path: Path to the file to convert.
app_config: Optional AppConfig (for pdf_converter preference). When
omitted, defaults to ``auto``.
Returns: Returns:
Path to the generated .md file, or None if conversion failed. Path to the generated .md file, or None if conversion failed.
""" """
try: try:
pdf_converter = _get_pdf_converter() pdf_converter = _get_pdf_converter(app_config)
file_size = file_path.stat().st_size file_size = file_path.stat().st_size
if file_size > _ASYNC_THRESHOLD_BYTES: if file_size > _ASYNC_THRESHOLD_BYTES:
@@ -288,28 +288,20 @@ def extract_outline(md_path: Path) -> list[dict]:
return outline return outline
def _get_uploads_config_value(key: str, default: object) -> object: def _get_pdf_converter(app_config: object | None) -> str:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _get_pdf_converter() -> str:
"""Read pdf_converter setting from app config, defaulting to 'auto'. """Read pdf_converter setting from app config, defaulting to 'auto'.
Normalizes the value to lowercase and validates it against the allowed set Normalizes the value to lowercase and validates it against the allowed set
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
fall through to unexpected behaviour. fall through to unexpected behaviour.
""" """
try: if app_config is None:
raw = str(_get_uploads_config_value("pdf_converter", "auto")).strip().lower() return "auto"
if raw not in _ALLOWED_PDF_CONVERTERS: uploads_cfg = getattr(app_config, "uploads", None)
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw) if uploads_cfg is None:
return "auto" return "auto"
return raw raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
except Exception: if raw not in _ALLOWED_PDF_CONVERTERS:
pass logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto" return "auto"
return raw
-3
View File
@@ -38,9 +38,6 @@ markers = [
"no_auto_user: disable the conftest autouse contextvar fixture for this test", "no_auto_user: disable the conftest autouse contextvar fixture for this test",
] ]
[tool.uv]
index-url = "https://pypi.org/simple"
[tool.uv.workspace] [tool.uv.workspace]
members = ["packages/harness"] members = ["packages/harness"]
+2 -1
View File
@@ -5,10 +5,11 @@ Usage:
The script is idempotent re-running it after a successful migration is a no-op. The script is idempotent re-running it after a successful migration is a no-op.
""" """
import argparse import argparse
import json
import logging import logging
import shutil import shutil
from pathlib import Path
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
+6 -1
View File
@@ -29,6 +29,7 @@ apps with the real middleware — those should not use this module.
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
from typing import ParamSpec, TypeVar
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4 from uuid import uuid4
@@ -112,7 +113,11 @@ def make_authed_test_app(
return app return app
def call_unwrapped[*P, R](decorated: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R: _P = ParamSpec("_P")
_R = TypeVar("_R")
def call_unwrapped(decorated: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""Invoke the underlying function of a ``@require_permission``-decorated route. """Invoke the underlying function of a ``@require_permission``-decorated route.
``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all ``functools.wraps`` sets ``__wrapped__`` on each layer; we walk all

Some files were not shown because too many files have changed in this diff Show More