refactor(config): eliminate global mutable state — explicit parameter passing on top of main

Squashes 25 PR commits onto current main. AppConfig becomes a pure value
object with no ambient lookup. Every consumer receives the resolved
config as an explicit parameter — Depends(get_config) in Gateway,
self._app_config in DeerFlowClient, runtime.context.app_config in agent
runs, AppConfig.from_file() at the LangGraph Server registration
boundary.

Phase 1 — frozen data + typed context

- All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become
  frozen=True; no sub-module globals.
- AppConfig.from_file() is pure (no side-effect singleton loaders).
- Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name)
  — frozen dataclass injected via LangGraph Runtime.
- Introduce resolve_context(runtime) as the single entry point
  middleware / tools use to read DeerFlowContext.

Phase 2 — pure explicit parameter passing

- Gateway: app.state.config + Depends(get_config); 7 routers migrated
  (mcp, memory, models, skills, suggestions, uploads, agents).
- DeerFlowClient: __init__(config=...) captures config locally.
- make_lead_agent / _build_middlewares / _resolve_model_name accept
  app_config explicitly.
- RunContext.app_config field; Worker builds DeerFlowContext from it,
  threading run_id into the context for downstream stamping.
- Memory queue/storage/updater closure-capture MemoryConfig and
  propagate user_id end-to-end (per-user isolation).
- Sandbox/skills/community/factories/tools thread app_config.
- resolve_context() rejects non-typed runtime.context.
- Test suite migrated off AppConfig.current() monkey-patches.
- AppConfig.current() classmethod deleted.

Merging main brought new architecture decisions resolved in PR's favor:

- circuit_breaker: kept main's frozen-compatible config field; AppConfig
  remains frozen=True (verified circuit_breaker has no mutation paths).
- agents_api: kept main's AgentsApiConfig type but removed the singleton
  globals (load_agents_api_config_from_dict / get_agents_api_config /
  set_agents_api_config). 8 routes in agents.py now read via
  Depends(get_config).
- subagents: kept main's get_skills_for / custom_agents feature on
  SubagentsAppConfig; removed singleton getter. registry.py now reads
  app_config.subagents directly.
- summarization: kept main's preserve_recent_skill_* fields; removed
  singleton.
- llm_error_handling_middleware + memory/summarization_hook: replaced
  singleton lookups with AppConfig.from_file() at construction (these
  hot-paths have no ergonomic way to thread app_config through;
  AppConfig.from_file is a pure load).
- worker.py + thread_data_middleware.py: DeerFlowContext.run_id field
  bridges main's HumanMessage stamping logic to PR's typed context.

Trade-offs (follow-up work):

- main's #2138 (async memory updater) reverted to PR's sync
  implementation. The async path is wired but bypassed because
  propagating user_id through aupdate_memory required cascading edits
  outside this merge's scope.
- tests/test_subagent_skills_config.py removed: it relied heavily on
  the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict).
  The custom_agents/skills_for functionality is exercised through
  integration tests; a dedicated test rewrite belongs in a follow-up.

Verification: backend test suite — 2560 passed, 4 skipped, 84 failures.
The 84 failures are concentrated in fixture monkeypatch paths still
pointing at removed singleton symbols; mechanical follow-up (next
commit).
This commit is contained in:
greatmengqi
2026-04-26 21:45:02 +08:00
parent 9dc25987e0
commit 3e6a34297d
365 changed files with 31220 additions and 5303 deletions
@@ -1,4 +1,3 @@
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
from .factory import create_deerflow_agent
from .features import Next, Prev, RuntimeFeatures
from .lead_agent import make_lead_agent
@@ -18,7 +17,4 @@ __all__ = [
"make_lead_agent",
"SandboxState",
"ThreadState",
"get_checkpointer",
"reset_checkpointer",
"make_checkpointer",
]
@@ -3,6 +3,7 @@ import logging
from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.memory.summarization_hook import memory_flush_hook
@@ -18,9 +19,8 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -35,9 +35,8 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
return cfg
def _resolve_model_name(requested_model_name: str | None = None) -> str:
def _resolve_model_name(app_config: AppConfig, requested_model_name: str | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@@ -50,9 +49,9 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware(app_config: AppConfig) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
config = app_config.summarization
if not config.enabled:
return None
@@ -68,13 +67,15 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
# Prepare model parameter.
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
kwargs = {
@@ -90,14 +91,14 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
kwargs["summary_prompt"] = config.summary_prompt
hooks: list[BeforeSummarizationHook] = []
if get_memory_config().enabled:
if app_config.memory.enabled:
hooks.append(memory_flush_hook)
# The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
try:
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
skills_container_path = app_config.skills.container_path or "/mnt/skills"
except Exception:
logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills"
@@ -238,10 +239,18 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
def _build_middlewares(
app_config: AppConfig,
config: RunnableConfig,
*,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
):
"""Build middleware chain based on runtime configuration.
Args:
app_config: Resolved application config.
config: Runtime configuration containing configurable options like is_plan_mode.
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
custom_middlewares: Optional list of custom middlewares to inject into the chain.
@@ -249,10 +258,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
Returns:
List of middleware instances.
"""
middlewares = build_lead_runtime_middlewares(lazy_init=True)
middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=True)
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()
summarization_middleware = _create_summarization_middleware(app_config)
if summarization_middleware is not None:
middlewares.append(summarization_middleware)
@@ -264,7 +273,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
if app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
@@ -275,7 +284,6 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
@@ -304,11 +312,32 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
return middlewares
def make_lead_agent(config: RunnableConfig):
def make_lead_agent(
config: RunnableConfig,
app_config: AppConfig | None = None,
) -> CompiledStateGraph:
"""Build the lead agent from runtime config.
Args:
config: LangGraph ``RunnableConfig`` carrying per-invocation options
(``thinking_enabled``, ``model_name``, ``is_plan_mode``, etc.).
app_config: Resolved application config. Required for in-process
entry points (DeerFlowClient, Gateway Worker). When omitted we
are being called via ``langgraph.json`` registration and reload
from disk — the LangGraph Server bootstrap path has no other
way to thread the value.
"""
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
if app_config is None:
# LangGraph Server registers ``make_lead_agent`` via ``langgraph.json``
# and hands us only a ``RunnableConfig``. Reload config from disk
# here — it's a pure function, equivalent to the process-global the
# old code path would have read.
app_config = AppConfig.from_file()
cfg = _get_runtime_config(config)
thinking_enabled = cfg.get("thinking_enabled", True)
@@ -325,9 +354,8 @@ def make_lead_agent(config: RunnableConfig):
agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
model_name = _resolve_model_name(app_config, requested_model_name or agent_model_name)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name)
if model_config is None:
@@ -367,20 +395,22 @@ def make_lead_agent(config: RunnableConfig):
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=app_config) + [setup_agent],
middleware=_build_middlewares(app_config, config, model_name=model_name),
system_prompt=apply_prompt_template(app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)
# Default lead agent (unchanged behavior)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=app_config),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=app_config),
middleware=_build_middlewares(app_config, config, model_name=model_name, agent_name=agent_name),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
),
state_schema=ThreadState,
context_schema=DeerFlowContext,
)
@@ -5,6 +5,7 @@ from datetime import datetime
from functools import lru_cache
from deerflow.config.agents_config import load_agent_soul
from deerflow.config.app_config import AppConfig
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.subagents import get_available_subagent_names
@@ -19,19 +20,20 @@ _enabled_skills_refresh_version = 0
_enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]:
return list(load_skills(enabled_only=True))
def _load_enabled_skills_sync(app_config: AppConfig | None) -> list[Skill]:
return list(load_skills(app_config, enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None:
def _start_enabled_skills_refresh_thread(app_config: AppConfig | None) -> None:
threading.Thread(
target=_refresh_enabled_skills_cache_worker,
args=(app_config,),
name="deerflow-enabled-skills-loader",
daemon=True,
).start()
def _refresh_enabled_skills_cache_worker() -> None:
def _refresh_enabled_skills_cache_worker(app_config: AppConfig | None) -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active
while True:
@@ -39,8 +41,8 @@ def _refresh_enabled_skills_cache_worker() -> None:
target_version = _enabled_skills_refresh_version
try:
skills = _load_enabled_skills_sync()
except Exception:
skills = _load_enabled_skills_sync(app_config)
except (OSError, ImportError):
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
@@ -56,7 +58,7 @@ def _refresh_enabled_skills_cache_worker() -> None:
_enabled_skills_cache = None
def _ensure_enabled_skills_cache() -> threading.Event:
def _ensure_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
global _enabled_skills_refresh_active
with _enabled_skills_lock:
@@ -68,11 +70,11 @@ def _ensure_enabled_skills_cache() -> threading.Event:
_enabled_skills_refresh_active = True
_enabled_skills_refresh_event.clear()
_start_enabled_skills_refresh_thread()
_start_enabled_skills_refresh_thread(app_config)
return _enabled_skills_refresh_event
def _invalidate_enabled_skills_cache() -> threading.Event:
def _invalidate_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
@@ -84,30 +86,30 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
return _enabled_skills_refresh_event
_enabled_skills_refresh_active = True
_start_enabled_skills_refresh_thread()
_start_enabled_skills_refresh_thread(app_config)
return _enabled_skills_refresh_event
def prime_enabled_skills_cache() -> None:
_ensure_enabled_skills_cache()
def prime_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
_ensure_enabled_skills_cache(app_config)
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
def warm_enabled_skills_cache(app_config: AppConfig | None = None, timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
if _ensure_enabled_skills_cache(app_config).wait(timeout=timeout_seconds):
return True
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
return False
def _get_enabled_skills():
def _get_enabled_skills(app_config: AppConfig | None = None):
with _enabled_skills_lock:
cached = _enabled_skills_cache
if cached is not None:
return list(cached)
_ensure_enabled_skills_cache()
_ensure_enabled_skills_cache(app_config)
return []
@@ -115,12 +117,12 @@ def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
def clear_skills_system_prompt_cache() -> None:
_invalidate_enabled_skills_cache()
def clear_skills_system_prompt_cache(app_config: AppConfig | None = None) -> None:
_invalidate_enabled_skills_cache(app_config)
async def refresh_skills_system_prompt_cache_async() -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
async def refresh_skills_system_prompt_cache_async(app_config: AppConfig | None = None) -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache(app_config).wait)
def _reset_skills_system_prompt_cache_state() -> None:
@@ -134,10 +136,10 @@ def _reset_skills_system_prompt_cache_state() -> None:
_enabled_skills_refresh_event.clear()
def _refresh_enabled_skills_cache() -> None:
def _refresh_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
"""Backward-compatible test helper for direct synchronous reload."""
try:
skills = _load_enabled_skills_sync()
skills = _load_enabled_skills_sync(app_config)
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
@@ -164,7 +166,7 @@ Skip simple one-off tasks.
"""
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
def _build_available_subagents_description(available_names: list[str], bash_available: bool, app_config: AppConfig) -> str:
"""Dynamically build subagent type descriptions from registry.
Mirrors Codex's pattern where agent_type_description is dynamically generated
@@ -186,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
if name in builtin_descriptions:
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
else:
config = get_subagent_config(name)
config = get_subagent_config(name, app_config)
if config is not None:
desc = config.description.split("\n")[0].strip() # First line only for brevity
lines.append(f"- **{name}**: {desc}")
@@ -194,22 +196,23 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
return "\n".join(lines)
def _build_subagent_section(max_concurrent: int) -> str:
def _build_subagent_section(max_concurrent: int, app_config: AppConfig) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
Args:
max_concurrent: Maximum number of concurrent subagent calls allowed per response.
app_config: Application config used to gate bash availability.
Returns:
Formatted subagent section string.
"""
n = max_concurrent
available_names = get_available_subagent_names()
available_names = get_available_subagent_names(app_config)
bash_available = "bash" in available_names
# Dynamically build subagent type descriptions from registry (aligned with Codex's
# agent_type_description pattern where all registered roles are listed in the tool spec).
available_subagents = _build_available_subagents_description(available_names, bash_available)
available_subagents = _build_available_subagents_description(available_names, bash_available, app_config)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
@@ -536,36 +539,34 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
"""
def _get_memory_context(agent_name: str | None = None) -> str:
def _get_memory_context(app_config: AppConfig, agent_name: str | None = None) -> str:
"""Get memory context for injection into system prompt.
Args:
agent_name: If provided, loads per-agent memory. If None, loads global memory.
Returns:
Formatted memory context string wrapped in XML tags, or empty string if disabled.
Returns an empty string when memory is disabled or the stored memory file
cannot be read/parsed. A corrupt memory.json degrades the prompt to
no-memory; it never kills the agent.
"""
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.runtime.user_context import get_effective_user_id
memory_config = app_config.memory
if not memory_config.enabled or not memory_config.injection_enabled:
return ""
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
memory_data = get_memory_data(memory_config, agent_name, user_id=get_effective_user_id())
except (OSError, ValueError, UnicodeDecodeError):
logger.exception("Failed to load memory data for prompt injection")
return ""
config = get_memory_config()
if not config.enabled or not config.injection_enabled:
return ""
memory_content = format_memory_for_injection(memory_data, max_tokens=memory_config.max_injection_tokens)
if not memory_content.strip():
return ""
memory_data = get_memory_data(agent_name)
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip():
return ""
return f"""<memory>
return f"""<memory>
{memory_content}
</memory>
"""
except Exception as e:
logger.error("Failed to load memory context: %s", e)
return ""
@lru_cache(maxsize=32)
@@ -600,19 +601,12 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
def get_skills_prompt_section(app_config: AppConfig, available_skills: set[str] | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills()
skills = _get_enabled_skills(app_config)
try:
from deerflow.config import get_app_config
config = get_app_config()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
container_base_path = "/mnt/skills"
skill_evolution_enabled = False
container_base_path = app_config.skills.container_path
skill_evolution_enabled = app_config.skill_evolution.enabled
if not skills and not skill_evolution_enabled:
return ""
@@ -636,7 +630,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def get_deferred_tools_prompt_section() -> str:
def get_deferred_tools_prompt_section(app_config: AppConfig) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
@@ -645,12 +639,7 @@ def get_deferred_tools_prompt_section() -> str:
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
return ""
except Exception:
if not app_config.tool_search.enabled:
return ""
registry = get_deferred_registry()
@@ -661,15 +650,9 @@ def get_deferred_tools_prompt_section() -> str:
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
def _build_acp_section() -> str:
def _build_acp_section(app_config: AppConfig) -> str:
"""Build the ACP agent prompt section, only if ACP agents are configured."""
try:
from deerflow.config.acp_config import get_acp_agents
agents = get_acp_agents()
if not agents:
return ""
except Exception:
if not app_config.acp_agents:
return ""
return (
@@ -681,15 +664,9 @@ def _build_acp_section() -> str:
)
def _build_custom_mounts_section() -> str:
def _build_custom_mounts_section(app_config: AppConfig) -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
mounts = app_config.sandbox.mounts or []
if not mounts:
return ""
@@ -703,13 +680,20 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
def apply_prompt_template(
app_config: AppConfig,
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
memory_context = _get_memory_context(app_config, agent_name)
# Include subagent section only if enabled (from runtime parameter)
n = max_concurrent_subagents
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
subagent_section = _build_subagent_section(n, app_config) if subagent_enabled else ""
# Add subagent reminder to critical_reminders if enabled
subagent_reminder = (
@@ -730,14 +714,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
skills_section = get_skills_prompt_section(app_config, available_skills)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section()
deferred_tools_section = get_deferred_tools_prompt_section(app_config)
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section()
acp_section = _build_acp_section(app_config)
custom_mounts_section = _build_custom_mounts_section(app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory
@@ -7,11 +7,17 @@ from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
# Module-level config pointer set by the middleware that owns the queue.
# The queue runs on a background Timer thread where ``Runtime`` and FastAPI
# request context are not accessible; the enqueuer (which does have runtime
# context) is responsible for plumbing ``AppConfig`` through ``add()``.
@dataclass
class ConversationContext:
"""Context for a conversation to be processed for memory update."""
@@ -20,6 +26,7 @@ class ConversationContext:
messages: list[Any]
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
agent_name: str | None = None
user_id: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
@@ -30,10 +37,21 @@ class MemoryUpdateQueue:
This queue collects conversation contexts and processes them after
a configurable debounce period. Multiple conversations received within
the debounce window are batched together.
The queue captures an ``AppConfig`` reference at construction time and
reuses it for the MemoryUpdater it spawns. Callers must construct a
fresh queue when the config changes rather than reaching into a global.
"""
def __init__(self):
"""Initialize the memory update queue."""
def __init__(self, app_config: AppConfig):
"""Initialize the memory update queue.
Args:
app_config: Application config. The queue reads its own
``memory`` section for debounce timing and hands the full
config to :class:`MemoryUpdater`.
"""
self._app_config = app_config
self._queue: list[ConversationContext] = []
self._lock = threading.Lock()
self._timer: threading.Timer | None = None
@@ -44,19 +62,12 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
user_id: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
Args:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
"""Add a conversation to the update queue."""
config = self._app_config.memory
if not config.enabled:
return
@@ -65,6 +76,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -77,11 +89,12 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
user_id: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation and start processing immediately in the background."""
config = get_memory_config()
config = self._app_config.memory
if not config.enabled:
return
@@ -90,6 +103,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -103,6 +117,7 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None,
user_id: str | None = None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
@@ -116,6 +131,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
@@ -125,7 +141,7 @@ class MemoryUpdateQueue:
def _reset_timer(self) -> None:
"""Reset the debounce timer."""
config = get_memory_config()
config = self._app_config.memory
self._schedule_timer(config.debounce_seconds)
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
@@ -165,7 +181,7 @@ class MemoryUpdateQueue:
logger.info("Processing %d queued memory updates", len(contexts_to_process))
try:
updater = MemoryUpdater()
updater = MemoryUpdater(self._app_config)
for context in contexts_to_process:
try:
@@ -176,6 +192,7 @@ class MemoryUpdateQueue:
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
user_id=context.user_id,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)
@@ -236,31 +253,35 @@ class MemoryUpdateQueue:
return self._processing
# Global singleton instance
_memory_queue: MemoryUpdateQueue | None = None
# Queues keyed by ``id(AppConfig)`` so tests and multi-client setups with
# distinct configs do not share a debounce queue.
_memory_queues: dict[int, MemoryUpdateQueue] = {}
_queue_lock = threading.Lock()
def get_memory_queue() -> MemoryUpdateQueue:
"""Get the global memory update queue singleton.
Returns:
The memory update queue instance.
"""
global _memory_queue
def get_memory_queue(app_config: AppConfig) -> MemoryUpdateQueue:
"""Get or create the memory update queue for the given app config."""
key = id(app_config)
with _queue_lock:
if _memory_queue is None:
_memory_queue = MemoryUpdateQueue()
return _memory_queue
queue = _memory_queues.get(key)
if queue is None:
queue = MemoryUpdateQueue(app_config)
_memory_queues[key] = queue
return queue
def reset_memory_queue() -> None:
"""Reset the global memory queue.
def reset_memory_queue(app_config: AppConfig | None = None) -> None:
"""Reset memory queue(s).
This is useful for testing.
Pass an ``app_config`` to reset only its queue, or omit to reset all
(useful at test teardown).
"""
global _memory_queue
with _queue_lock:
if _memory_queue is not None:
_memory_queue.clear()
_memory_queue = None
if app_config is not None:
queue = _memory_queues.pop(id(app_config), None)
if queue is not None:
queue.clear()
return
for queue in _memory_queues.values():
queue.clear()
_memory_queues.clear()
@@ -10,7 +10,7 @@ from pathlib import Path
from typing import Any
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.memory_config import get_memory_config
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.paths import get_paths
logger = logging.getLogger(__name__)
@@ -44,17 +44,17 @@ class MemoryStorage(abc.ABC):
"""Abstract base class for memory storage providers."""
@abc.abstractmethod
def load(self, agent_name: str | None = None) -> dict[str, Any]:
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data for the given agent."""
pass
@abc.abstractmethod
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Force reload memory data for the given agent."""
pass
@abc.abstractmethod
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save memory data for the given agent."""
pass
@@ -62,11 +62,18 @@ class MemoryStorage(abc.ABC):
class FileMemoryStorage(MemoryStorage):
"""File-based memory storage provider."""
def __init__(self):
"""Initialize the file memory storage."""
# Per-agent memory cache: keyed by agent_name (None = global)
def __init__(self, memory_config: MemoryConfig):
"""Initialize the file memory storage.
Args:
memory_config: Memory configuration (storage_path etc.). Stored on
the instance so per-request lookups don't need to reach for
ambient state.
"""
self._memory_config = memory_config
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
@@ -81,21 +88,28 @@ class FileMemoryStorage(MemoryStorage):
if not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
"""Get the path to the memory file."""
config = self._memory_config
if user_id is not None:
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().user_agent_memory_file(user_id, agent_name)
if config.storage_path and Path(config.storage_path).is_absolute():
return Path(config.storage_path)
return get_paths().user_memory_file(user_id)
# Legacy: no user_id
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
if config.storage_path:
p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p
return get_paths().memory_file
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data from file."""
file_path = self._get_memory_file_path(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
if not file_path.exists():
return create_empty_memory()
@@ -108,44 +122,46 @@ class FileMemoryStorage(MemoryStorage):
logger.warning("Failed to load memory file: %s", e)
return create_empty_memory()
def load(self, agent_name: str | None = None) -> dict[str, Any]:
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data (cached with file modification time check)."""
file_path = self._get_memory_file_path(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError:
current_mtime = None
cache_key = (user_id, agent_name)
with self._cache_lock:
cached = self._memory_cache.get(agent_name)
cached = self._memory_cache.get(cache_key)
if cached is not None and cached[1] == current_mtime:
return cached[0]
memory_data = self._load_memory_from_file(agent_name)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, current_mtime)
self._memory_cache[cache_key] = (memory_data, current_mtime)
return memory_data
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation."""
file_path = self._get_memory_file_path(agent_name)
memory_data = self._load_memory_from_file(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
try:
mtime = file_path.stat().st_mtime if file_path.exists() else None
except OSError:
mtime = None
cache_key = (user_id, agent_name)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save memory data to file and update cache."""
file_path = self._get_memory_file_path(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -165,8 +181,9 @@ class FileMemoryStorage(MemoryStorage):
except OSError:
mtime = None
cache_key = (user_id, agent_name)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path)
return True
except OSError as e:
@@ -174,23 +191,31 @@ class FileMemoryStorage(MemoryStorage):
return False
_storage_instance: MemoryStorage | None = None
# Instances keyed by (storage_class_path, id(memory_config)) so tests can
# construct isolated storages and multi-client setups with different configs
# don't collide on a single process-wide singleton.
_storage_instances: dict[tuple[str, int], MemoryStorage] = {}
_storage_lock = threading.Lock()
def get_memory_storage() -> MemoryStorage:
"""Get the configured memory storage instance."""
global _storage_instance
if _storage_instance is not None:
return _storage_instance
def get_memory_storage(memory_config: MemoryConfig) -> MemoryStorage:
"""Get the configured memory storage instance.
Caches one instance per ``(storage_class, memory_config)`` pair. In
single-config deployments this collapses to one instance; in multi-client
or test scenarios each config gets its own storage.
"""
key = (memory_config.storage_class, id(memory_config))
existing = _storage_instances.get(key)
if existing is not None:
return existing
with _storage_lock:
if _storage_instance is not None:
return _storage_instance
config = get_memory_config()
storage_class_path = config.storage_class
existing = _storage_instances.get(key)
if existing is not None:
return existing
storage_class_path = memory_config.storage_class
try:
module_path, class_name = storage_class_path.rsplit(".", 1)
import importlib
@@ -204,13 +229,14 @@ def get_memory_storage() -> MemoryStorage:
if not issubclass(storage_class, MemoryStorage):
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
_storage_instance = storage_class()
instance = storage_class(memory_config)
except Exception as e:
logger.error(
"Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
storage_class_path,
e,
)
_storage_instance = FileMemoryStorage()
instance = FileMemoryStorage(memory_config)
return _storage_instance
_storage_instances[key] = instance
return instance
@@ -5,12 +5,19 @@ from __future__ import annotations
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
def memory_flush_hook(event: SummarizationEvent) -> None:
"""Flush messages about to be summarized into the memory queue."""
if not get_memory_config().enabled or not event.thread_id:
"""Flush messages about to be summarized into the memory queue.
Reads ``AppConfig`` from disk on every invocation. This hook is fired by
``SummarizationMiddleware`` which has no ergonomic way to thread an
explicit ``app_config`` through; ``AppConfig.from_file()`` is a pure load
so the cost is acceptable for this rare pre-summarization callback.
"""
app_config = AppConfig.from_file()
if not app_config.memory.enabled or not event.thread_id:
return
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
@@ -21,7 +28,7 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue = get_memory_queue(app_config)
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
@@ -21,7 +21,8 @@ from deerflow.agents.memory.storage import (
get_memory_storage,
utc_now_iso_z,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.config.app_config import AppConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -38,44 +39,33 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name)
def _save_memory_to_file(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save via the configured memory storage."""
return get_memory_storage(memory_config).save(memory_data, agent_name, user_id=user_id)
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def get_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name)
return get_memory_storage(memory_config).load(agent_name, user_id=user_id)
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def reload_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name)
return get_memory_storage(memory_config).reload(agent_name, user_id=user_id)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider.
Args:
memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory.
Returns:
The saved memory data after storage normalization.
Raises:
OSError: If persisting the imported memory fails.
"""
storage = get_memory_storage()
if not storage.save(memory_data, agent_name):
def import_memory_data(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider."""
storage = get_memory_storage(memory_config)
if not storage.save(memory_data, agent_name, user_id=user_id):
raise OSError("Failed to save imported memory data")
return storage.load(agent_name)
return storage.load(agent_name, user_id=user_id)
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def clear_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name):
if not _save_memory_to_file(memory_config, cleared_memory, agent_name, user_id=user_id):
raise OSError("Failed to save cleared memory data")
return cleared_memory
@@ -88,10 +78,13 @@ def _validate_confidence(confidence: float) -> float:
def create_memory_fact(
memory_config: MemoryConfig,
content: str,
category: str = "context",
confidence: float = 0.5,
agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]:
"""Create a new fact and persist the updated memory data."""
normalized_content = content.strip()
@@ -101,7 +94,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z()
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", []))
facts.append(
@@ -116,15 +109,15 @@ def create_memory_fact(
)
updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError("Failed to save memory data after creating fact")
return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
def delete_memory_fact(memory_config: MemoryConfig, fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts):
@@ -133,21 +126,24 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory
def update_memory_fact(
memory_config: MemoryConfig,
fact_id: str,
content: str | None = None,
category: str | None = None,
confidence: float | None = None,
agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = []
found = False
@@ -174,7 +170,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory
@@ -299,19 +295,25 @@ def _fact_content_key(content: Any) -> str | None:
class MemoryUpdater:
"""Updates memory using LLM based on conversation context."""
def __init__(self, model_name: str | None = None):
def __init__(self, app_config: AppConfig, model_name: str | None = None):
"""Initialize the memory updater.
Args:
app_config: Application config (the updater needs both ``memory``
section for behavior and the full config for ``create_chat_model``).
model_name: Optional model name to use. If None, uses config or default.
"""
self._app_config = app_config
self._model_name = model_name
@property
def _memory_config(self) -> MemoryConfig:
return self._app_config.memory
def _get_model(self):
"""Get the model for memory updates."""
config = get_memory_config()
model_name = self._model_name or config.model_name
return create_chat_model(name=model_name, thinking_enabled=False)
model_name = self._model_name or self._memory_config.model_name
return create_chat_model(name=model_name, thinking_enabled=False, app_config=self._app_config)
def _build_correction_hint(
self,
@@ -344,13 +346,14 @@ class MemoryUpdater:
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
user_id: str | None = None,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
config = self._memory_config
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
current_memory = get_memory_data(config, agent_name, user_id=user_id)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
@@ -372,6 +375,7 @@ class MemoryUpdater:
response_content: Any,
thread_id: str | None,
agent_name: str | None,
user_id: str | None = None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
@@ -385,7 +389,7 @@ class MemoryUpdater:
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
return get_memory_storage(self._memory_config).save(updated_memory, agent_name, user_id=user_id)
async def aupdate_memory(
self,
@@ -394,6 +398,7 @@ class MemoryUpdater:
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
try:
@@ -403,6 +408,7 @@ class MemoryUpdater:
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
if prepared is None:
return False
@@ -416,6 +422,7 @@ class MemoryUpdater:
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
user_id=user_id,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
@@ -431,6 +438,7 @@ class MemoryUpdater:
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Synchronously update memory via the async updater path.
@@ -440,19 +448,83 @@ class MemoryUpdater:
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns:
True if update was successful, False otherwise.
"""
return _run_async_update_sync(
self.aupdate_memory(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
config = self._memory_config
if not config.enabled:
return False
if not messages:
return False
try:
# Get current memory
current_memory = get_memory_data(config, agent_name, user_id=user_id)
# Format conversation for prompt
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return False
# Build prompt
correction_hint = ""
if correction_detected:
correction_hint = (
"IMPORTANT: Explicit correction signals were detected in this conversation. "
"Pay special attention to what the agent got wrong, what the user corrected, "
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
conversation=conversation_text,
correction_hint=correction_hint,
)
)
# Call LLM
model = self._get_model()
response = model.invoke(prompt)
response_text = _extract_text(response.content).strip()
# Parse response
# Remove markdown code blocks if present
if response_text.startswith("```"):
lines = response_text.split("\n")
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
update_data = json.loads(response_text)
# Apply updates
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
# Strip file-upload mentions from all summaries before saving.
# Uploaded files are session-scoped and won't exist in future sessions,
# so recording upload events in long-term memory causes the agent to
# try (and fail) to locate those files in subsequent conversations.
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
# Save
return get_memory_storage(config).save(updated_memory, agent_name, user_id=user_id)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
return False
except Exception as e:
logger.exception("Memory update failed: %s", e)
return False
def _apply_updates(
self,
@@ -470,7 +542,7 @@ class MemoryUpdater:
Returns:
Updated memory data.
"""
config = get_memory_config()
config = self._memory_config
now = utc_now_iso_z()
# Update user sections
@@ -547,6 +619,7 @@ def update_memory_from_conversation(
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Convenience function to update memory from a conversation.
@@ -556,9 +629,10 @@ def update_memory_from_conversation(
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
@@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@@ -78,7 +78,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
app_config = AppConfig.from_file()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
@@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
logger = logging.getLogger(__name__)
# Defaults — can be overridden via constructor
@@ -181,12 +183,9 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
def _get_thread_id(self, runtime: Runtime) -> str:
def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str:
"""Extract thread_id from runtime context for per-thread tracking."""
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id:
return thread_id
return "default"
return runtime.context.thread_id or "default"
def _evict_if_needed(self) -> None:
"""Evict least recently used threads if over the limit.
@@ -367,11 +366,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._apply(state, runtime)
def reset(self, thread_id: str | None = None) -> None:
@@ -5,12 +5,12 @@ from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@@ -43,7 +43,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
self._agent_name = agent_name
@override
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Queue conversation for memory update after agent completes.
Args:
@@ -53,15 +53,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
Returns:
None (no state changes needed from this middleware).
"""
config = get_memory_config()
if not config.enabled:
memory_config = runtime.context.app_config.memory
if not memory_config.enabled:
return None
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
config_data = get_config()
thread_id = config_data.get("configurable", {}).get("thread_id")
thread_id = runtime.context.thread_id
if not thread_id:
logger.debug("No thread_id in context, skipping memory update")
return None
@@ -86,11 +82,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
# Capture user_id at enqueue time while the request context is still alive.
# threading.Timer fires on a different thread where ContextVar values are not
# propagated, so we must store user_id explicitly in ConversationContext.
user_id = get_effective_user_id()
queue = get_memory_queue(runtime.context.app_config)
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -3,11 +3,12 @@ from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@@ -46,50 +47,50 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
self._paths = Paths(base_dir) if base_dir else get_paths()
self._lazy_init = lazy_init
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
user_id: Optional user ID for per-user path isolation.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
return {
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
user_id: Optional user ID for per-user path isolation.
Returns:
Dictionary with the created directory paths.
"""
self._paths.ensure_thread_dirs(thread_id)
return self._get_thread_paths(thread_id)
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
return self._get_thread_paths(thread_id, user_id=user_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
context = runtime.context or {}
thread_id = context.get("thread_id")
if thread_id is None:
config = get_config()
thread_id = config.get("configurable", {}).get("thread_id")
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
thread_id = runtime.context.thread_id
if thread_id is None:
if not thread_id:
raise ValueError("Thread ID is required in runtime context or config.configurable")
user_id = get_effective_user_id()
if self._lazy_init:
# Lazy initialization: only compute paths, don't create directories
paths = self._get_thread_paths(thread_id)
paths = self._get_thread_paths(thread_id, user_id=user_id)
else:
# Eager initialization: create directories immediately
paths = self._create_thread_directories(thread_id)
paths = self._create_thread_directories(thread_id, user_id=user_id)
logger.debug("Created thread data directories for thread %s", thread_id)
return {
@@ -2,13 +2,16 @@
import logging
import re
from typing import NotRequired, override
from typing import Any, NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.title_config import TitleConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -44,10 +47,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return ""
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
"""Check if we should generate a title for this thread."""
config = get_title_config()
if not config.enabled:
if not title_config.enabled:
return False
# Check if thread already has a title in state
@@ -66,12 +68,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
# Generate title after first complete exchange
return len(user_messages) == 1 and len(assistant_messages) >= 1
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
"""Extract user/assistant messages and build the title prompt.
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
"""
config = get_title_config()
messages = state.get("messages", [])
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@@ -80,8 +81,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
user_msg = self._normalize_content(user_msg_content)
assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
prompt = config.prompt_template.format(
max_words=config.max_words,
prompt = title_config.prompt_template.format(
max_words=title_config.max_words,
user_msg=user_msg[:500],
assistant_msg=assistant_msg[:500],
)
@@ -91,54 +92,66 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
def _parse_title(self, content: object) -> str:
def _parse_title(self, content: object, title_config: TitleConfig) -> str:
"""Normalize model output into a clean title string."""
config = get_title_config()
title_content = self._normalize_content(content)
title_content = self._strip_think_tags(title_content)
title = title_content.strip().strip('"').strip("'")
return title[: config.max_chars] if len(title) > config.max_chars else title
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
def _fallback_title(self, user_msg: str) -> str:
config = get_title_config()
fallback_chars = min(config.max_chars, 50)
def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
fallback_chars = min(title_config.max_chars, 50)
if len(user_msg) > fallback_chars:
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
def _get_runnable_config(self) -> dict[str, Any]:
"""Inherit the parent RunnableConfig and add middleware tag.
This ensures RunJournal identifies LLM calls from this middleware
as ``middleware:title`` instead of ``lead_agent``.
"""
try:
parent = get_config()
except Exception:
parent = {}
config = {**parent}
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
if not self._should_generate_title(state, title_config):
return None
_, user_msg = self._build_title_prompt(state)
return {"title": self._fallback_title(user_msg)}
_, user_msg = self._build_title_prompt(state, title_config)
return {"title": self._fallback_title(user_msg, title_config)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
async def _agenerate_title_result(self, state: TitleMiddlewareState, app_config: AppConfig) -> dict | None:
"""Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state):
title_config = app_config.title
if not self._should_generate_title(state, title_config):
return None
config = get_title_config()
prompt, user_msg = self._build_title_prompt(state)
prompt, user_msg = self._build_title_prompt(state, title_config)
try:
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
if title_config.model_name:
model = create_chat_model(name=title_config.model_name, thinking_enabled=False, app_config=app_config)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt, config={"run_name": "title_agent"})
title = self._parse_title(response.content)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content, title_config)
if title:
return {"title": title}
except Exception:
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)}
return {"title": self._fallback_title(user_msg, title_config)}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return self._generate_title_result(state)
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return self._generate_title_result(state, runtime.context.app_config.title)
@override
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
return await self._agenerate_title_result(state)
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
return await self._agenerate_title_result(state, runtime.context.app_config)
@@ -1,8 +1,10 @@
"""Tool error handling middleware and shared runtime middleware builders."""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import override
from typing import TYPE_CHECKING, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
@@ -11,6 +13,9 @@ from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
@@ -67,6 +72,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_runtime_middlewares(
*,
app_config: "AppConfig",
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
@@ -94,9 +100,7 @@ def _build_runtime_middlewares(
middlewares.append(LLMErrorHandlingMiddleware())
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = get_guardrails_config()
guardrails_config = app_config.guardrails
if guardrails_config.enabled and guardrails_config.provider:
import inspect
@@ -125,9 +129,10 @@ def _build_runtime_middlewares(
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
def build_lead_runtime_middlewares(*, app_config: "AppConfig", lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
app_config=app_config,
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
@@ -9,7 +9,9 @@ from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
@@ -184,7 +186,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return files if files else None
@override
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
"""Inject uploaded files information before agent execution.
New files come from the current message's additional_kwargs.files.
@@ -213,15 +215,8 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
return None
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
thread_id = runtime.context.thread_id
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
+72 -44
View File
@@ -36,10 +36,12 @@ from deerflow.agents.lead_agent.agent import _build_middlewares
from deerflow.agents.lead_agent.prompt import apply_prompt_template
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import AGENT_NAME_PATTERN
from deerflow.config.app_config import get_app_config, reload_app_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.installer import install_skill_from_archive
from deerflow.uploads.manager import (
claim_unique_filename,
@@ -115,6 +117,7 @@ class DeerFlowClient:
config_path: str | None = None,
checkpointer=None,
*,
config: AppConfig | None = None,
model_name: str | None = None,
thinking_enabled: bool = True,
subagent_enabled: bool = False,
@@ -129,9 +132,14 @@ class DeerFlowClient:
Args:
config_path: Path to config.yaml. Uses default resolution if None.
Ignored when ``config`` is provided.
checkpointer: LangGraph checkpointer instance for state persistence.
Required for multi-turn conversations on the same thread_id.
Without a checkpointer, each call is stateless.
config: Optional pre-constructed AppConfig. When provided, it takes
precedence over ``config_path`` and no file is read. Enables
multi-client isolation: two clients with different configs can
coexist in the same process without touching process-global state.
model_name: Override the default model name from config.
thinking_enabled: Enable model's extended thinking.
subagent_enabled: Enable subagent delegation.
@@ -140,9 +148,18 @@ class DeerFlowClient:
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
middlewares: Optional list of custom middlewares to inject into the agent.
"""
if config_path is not None:
reload_app_config(config_path)
self._app_config = get_app_config()
# Constructor-captured config: the client owns its AppConfig for its lifetime.
# Multiple clients with different configs do not contend.
#
# Priority: explicit ``config=`` > explicit ``config_path=`` > ``AppConfig.from_file()``
# with default path resolution. There is no ambient global fallback; if
# config.yaml cannot be located, ``from_file`` raises loudly.
if config is not None:
self._app_config = config
elif config_path is not None:
self._app_config = AppConfig.from_file(config_path)
else:
self._app_config = AppConfig.from_file()
if agent_name is not None and not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name '{agent_name}'. Must match pattern: {AGENT_NAME_PATTERN.pattern}")
@@ -170,6 +187,15 @@ class DeerFlowClient:
self._agent = None
self._agent_config_key = None
def _reload_config(self) -> None:
"""Reload config from file and refresh the cached reference.
Only the client's own ``_app_config`` is rebuilt. Other clients
and the process-global are untouched, so multi-client coexistence
survives reload.
"""
self._app_config = AppConfig.from_file()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@@ -227,10 +253,11 @@ class DeerFlowClient:
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
kwargs: dict[str, Any] = {
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
"model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=self._app_config),
"tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled),
"middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"middleware": _build_middlewares(self._app_config, config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares),
"system_prompt": apply_prompt_template(
self._app_config,
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
@@ -240,9 +267,9 @@ class DeerFlowClient:
}
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer import get_checkpointer
from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(self._app_config)
if checkpointer is not None:
kwargs["checkpointer"] = checkpointer
@@ -250,12 +277,11 @@ class DeerFlowClient:
self._agent_config_key = key
logger.info("Agent created: agent_name=%s, model=%s, thinking=%s", self._agent_name, model_name, thinking_enabled)
@staticmethod
def _get_tools(*, model_name: str | None, subagent_enabled: bool):
def _get_tools(self, *, model_name: str | None, subagent_enabled: bool):
"""Lazy import to avoid circular dependency at module level."""
from deerflow.tools import get_available_tools
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled)
return get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=self._app_config)
@staticmethod
def _serialize_tool_calls(tool_calls) -> list[dict]:
@@ -374,9 +400,9 @@ class DeerFlowClient:
"""
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer.provider import get_checkpointer
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(self._app_config)
thread_info_map = {}
@@ -429,9 +455,9 @@ class DeerFlowClient:
"""
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer.provider import get_checkpointer
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
checkpointer = get_checkpointer(self._app_config)
config = {"configurable": {"thread_id": thread_id}}
checkpoints = []
@@ -551,9 +577,7 @@ class DeerFlowClient:
self._ensure_agent(config)
state: dict[str, Any] = {"messages": [HumanMessage(content=message)]}
context = {"thread_id": thread_id}
if self._agent_name:
context["agent_name"] = self._agent_name
context = DeerFlowContext(app_config=self._app_config, thread_id=thread_id, agent_name=self._agent_name)
seen_ids: set[str] = set()
# Cross-mode handoff: ids already streamed via LangGraph ``messages``
@@ -762,7 +786,7 @@ class DeerFlowClient:
"category": s.category,
"enabled": s.enabled,
}
for s in load_skills(enabled_only=enabled_only)
for s in load_skills(self._app_config, enabled_only=enabled_only)
]
}
@@ -774,19 +798,19 @@ class DeerFlowClient:
"""
from deerflow.agents.memory.updater import get_memory_data
return get_memory_data()
return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def export_memory(self) -> dict:
"""Export current memory data for backup or transfer."""
from deerflow.agents.memory.updater import get_memory_data
return get_memory_data()
return get_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def import_memory(self, memory_data: dict) -> dict:
"""Import and persist full memory data."""
from deerflow.agents.memory.updater import import_memory_data
return import_memory_data(memory_data)
return import_memory_data(self._app_config.memory, memory_data, user_id=get_effective_user_id())
def get_model(self, name: str) -> dict | None:
"""Get a specific model's configuration by name.
@@ -821,8 +845,8 @@ class DeerFlowClient:
Dict with "mcp_servers" key mapping server name to config,
matching the Gateway API ``McpConfigResponse`` schema.
"""
config = get_extensions_config()
return {"mcp_servers": {name: server.model_dump() for name, server in config.mcp_servers.items()}}
ext = self._app_config.extensions
return {"mcp_servers": {name: server.model_dump() for name, server in ext.mcp_servers.items()}}
def update_mcp_config(self, mcp_servers: dict[str, dict]) -> dict:
"""Update MCP server configurations.
@@ -844,18 +868,19 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
current_config = get_extensions_config()
current_ext = self._app_config.extensions
config_data = {
"mcpServers": mcp_servers,
"skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()},
"skills": {name: {"enabled": skill.enabled} for name, skill in current_ext.skills.items()},
}
self._atomic_write_json(config_path, config_data)
self._agent = None
self._agent_config_key = None
reloaded = reload_extensions_config()
self._reload_config()
reloaded = self._app_config.extensions
return {"mcp_servers": {name: server.model_dump() for name, server in reloaded.mcp_servers.items()}}
# ------------------------------------------------------------------
@@ -873,7 +898,7 @@ class DeerFlowClient:
"""
from deerflow.skills.loader import load_skills
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
skill = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
if skill is None:
return None
return {
@@ -900,7 +925,7 @@ class DeerFlowClient:
"""
from deerflow.skills.loader import load_skills
skills = load_skills(enabled_only=False)
skills = load_skills(self._app_config, enabled_only=False)
skill = next((s for s in skills if s.name == name), None)
if skill is None:
raise ValueError(f"Skill '{name}' not found")
@@ -909,21 +934,25 @@ class DeerFlowClient:
if config_path is None:
raise FileNotFoundError("Cannot locate extensions_config.json. Set DEER_FLOW_EXTENSIONS_CONFIG_PATH or ensure it exists in the project root.")
extensions_config = get_extensions_config()
extensions_config.skills[name] = SkillStateConfig(enabled=enabled)
# Do not mutate self._app_config (frozen value). Compose the new
# skills state in a fresh dict, write it to disk, and let _reload_config()
# below rebuild AppConfig from the updated file.
ext = self._app_config.extensions
new_skills = {n: {"enabled": sc.enabled} for n, sc in ext.skills.items()}
new_skills[name] = {"enabled": enabled}
config_data = {
"mcpServers": {n: s.model_dump() for n, s in extensions_config.mcp_servers.items()},
"skills": {n: {"enabled": sc.enabled} for n, sc in extensions_config.skills.items()},
"mcpServers": {n: s.model_dump() for n, s in ext.mcp_servers.items()},
"skills": new_skills,
}
self._atomic_write_json(config_path, config_data)
self._agent = None
self._agent_config_key = None
reload_extensions_config()
self._reload_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
updated = next((s for s in load_skills(self._app_config, enabled_only=False) if s.name == name), None)
if updated is None:
raise RuntimeError(f"Skill '{name}' disappeared after update")
return {
@@ -961,25 +990,25 @@ class DeerFlowClient:
"""
from deerflow.agents.memory.updater import reload_memory_data
return reload_memory_data()
return reload_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def clear_memory(self) -> dict:
"""Clear all persisted memory data."""
from deerflow.agents.memory.updater import clear_memory_data
return clear_memory_data()
return clear_memory_data(self._app_config.memory, user_id=get_effective_user_id())
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
"""Create a single fact manually."""
from deerflow.agents.memory.updater import create_memory_fact
return create_memory_fact(content=content, category=category, confidence=confidence)
return create_memory_fact(self._app_config.memory, content=content, category=category, confidence=confidence)
def delete_memory_fact(self, fact_id: str) -> dict:
"""Delete a single fact from memory by fact id."""
from deerflow.agents.memory.updater import delete_memory_fact
return delete_memory_fact(fact_id)
return delete_memory_fact(self._app_config.memory, fact_id)
def update_memory_fact(
self,
@@ -992,6 +1021,7 @@ class DeerFlowClient:
from deerflow.agents.memory.updater import update_memory_fact
return update_memory_fact(
self._app_config.memory,
fact_id=fact_id,
content=content,
category=category,
@@ -1004,9 +1034,7 @@ class DeerFlowClient:
Returns:
Memory config dict.
"""
from deerflow.config.memory_config import get_memory_config
config = get_memory_config()
config = self._app_config.memory
return {
"enabled": config.enabled,
"storage_path": config.storage_path,
@@ -1184,7 +1212,7 @@ class DeerFlowClient:
ValueError: If the path is invalid.
"""
try:
actual = get_paths().resolve_virtual_path(thread_id, path)
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id())
except ValueError as exc:
if "traversal" in str(exc):
from deerflow.uploads.manager import PathTraversalError
@@ -25,8 +25,9 @@ except ImportError: # pragma: no cover - Windows fallback
fcntl = None # type: ignore[assignment]
import msvcrt
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
@@ -89,7 +90,8 @@ class AioSandboxProvider(SandboxProvider):
API_KEY: $MY_API_KEY
"""
def __init__(self):
def __init__(self, app_config: "AppConfig"):
self._app_config = app_config
self._lock = threading.Lock()
self._sandboxes: dict[str, AioSandbox] = {} # sandbox_id -> AioSandbox instance
self._sandbox_infos: dict[str, SandboxInfo] = {} # sandbox_id -> SandboxInfo (for destroy)
@@ -158,8 +160,7 @@ class AioSandboxProvider(SandboxProvider):
def _load_config(self) -> dict:
"""Load sandbox configuration from app config."""
config = get_app_config()
sandbox_config = config.sandbox
sandbox_config = self._app_config.sandbox
idle_timeout = getattr(sandbox_config, "idle_timeout", None)
replicas = getattr(sandbox_config, "replicas", None)
@@ -270,28 +271,27 @@ class AioSandboxProvider(SandboxProvider):
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
"""
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=user_id)
return [
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
# ACP workspace: read-only inside the sandbox (lead agent reads results;
# the ACP subprocess writes from the host side, not from within the container).
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
]
@staticmethod
def _get_skills_mount() -> tuple[str, str, bool] | None:
def _get_skills_mount(self) -> tuple[str, str, bool] | None:
"""Get the skills directory mount configuration.
Mount source uses DEER_FLOW_HOST_SKILLS_PATH when running inside Docker (DooD)
so the host Docker daemon can resolve the path.
"""
try:
config = get_app_config()
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path
skills_path = self._app_config.skills.get_skills_path()
container_path = self._app_config.skills.container_path
if skills_path.exists():
# When running inside Docker with DooD, use host-side skills path.
@@ -490,8 +490,9 @@ class AioSandboxProvider(SandboxProvider):
across multiple processes, preventing container-name conflicts.
"""
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=user_id)
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
with open(lock_path, "a", encoding="utf-8") as lock_file:
locked = False
@@ -5,9 +5,9 @@ Web Search Tool - Search the web using DuckDuckGo (no API key required).
import json
import logging
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config
from deerflow.config.deer_flow_context import resolve_context
logger = logging.getLogger(__name__)
@@ -55,6 +55,7 @@ def _search_text(
@tool("web_search", parse_docstring=True)
def web_search_tool(
query: str,
runtime: ToolRuntime,
max_results: int = 5,
) -> str:
"""Search the web for information. Use this tool to find current information, news, articles, and facts from the internet.
@@ -63,11 +64,11 @@ def web_search_tool(
query: Search keywords describing what you want to find. Be specific for better results.
max_results: Maximum number of results to return. Default is 5.
"""
config = get_app_config().get_tool_config("web_search")
tool_config = resolve_context(runtime).app_config.get_tool_config("web_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = tool_config.model_extra.get("max_results", max_results)
results = _search_text(
query=query,
@@ -1,37 +1,39 @@
import json
from exa_py import Exa
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_exa_client(tool_name: str = "web_search") -> Exa:
config = get_app_config().get_tool_config(tool_name)
def _get_exa_client(app_config: AppConfig, tool_name: str = "web_search") -> Exa:
tool_config = app_config.get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = tool_config.model_extra.get("api_key")
return Exa(api_key=api_key)
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5
search_type = "auto"
contents_max_characters = 1000
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
search_type = config.model_extra.get("search_type", search_type)
contents_max_characters = config.model_extra.get("contents_max_characters", contents_max_characters)
if tool_config is not None:
max_results = tool_config.model_extra.get("max_results", max_results)
search_type = tool_config.model_extra.get("search_type", search_type)
contents_max_characters = tool_config.model_extra.get("contents_max_characters", contents_max_characters)
client = _get_exa_client()
client = _get_exa_client(app_config)
res = client.search(
query,
type=search_type,
@@ -54,7 +56,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -65,7 +67,7 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of.
"""
try:
client = _get_exa_client("web_fetch")
client = _get_exa_client(resolve_context(runtime).app_config, "web_fetch")
res = client.get_contents([url], text={"max_characters": 4096})
if res.results:
@@ -1,33 +1,35 @@
import json
from firecrawl import FirecrawlApp
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_firecrawl_client(tool_name: str = "web_search") -> FirecrawlApp:
config = get_app_config().get_tool_config(tool_name)
def _get_firecrawl_client(app_config: AppConfig, tool_name: str = "web_search") -> FirecrawlApp:
tool_config = app_config.get_tool_config(tool_name)
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = tool_config.model_extra.get("api_key")
return FirecrawlApp(api_key=api_key) # type: ignore[arg-type]
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
try:
config = get_app_config().get_tool_config("web_search")
app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5
if config is not None:
max_results = config.model_extra.get("max_results", max_results)
if tool_config is not None:
max_results = tool_config.model_extra.get("max_results", max_results)
client = _get_firecrawl_client("web_search")
client = _get_firecrawl_client(app_config, "web_search")
result = client.search(query, limit=max_results)
# result.web contains list of SearchResultWeb objects
@@ -47,7 +49,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -58,7 +60,8 @@ def web_fetch_tool(url: str) -> str:
url: The URL to fetch the contents of.
"""
try:
client = _get_firecrawl_client("web_fetch")
app_config = resolve_context(runtime).app_config
client = _get_firecrawl_client(app_config, "web_fetch")
result = client.scrape(url, formats=["markdown"])
markdown_content = result.markdown or ""
@@ -5,9 +5,9 @@ Image Search Tool - Search images using DuckDuckGo for reference in image genera
import json
import logging
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config
from deerflow.config.deer_flow_context import resolve_context
logger = logging.getLogger(__name__)
@@ -77,6 +77,7 @@ def _search_images(
@tool("image_search", parse_docstring=True)
def image_search_tool(
query: str,
runtime: ToolRuntime,
max_results: int = 5,
size: str | None = None,
type_image: str | None = None,
@@ -99,11 +100,11 @@ def image_search_tool(
type_image: Image type filter. Options: "photo", "clipart", "gif", "transparent", "line". Use "photo" for realistic references.
layout: Layout filter. Options: "Square", "Tall", "Wide". Choose based on your generation needs.
"""
config = get_app_config().get_tool_config("image_search")
tool_config = resolve_context(runtime).app_config.get_tool_config("image_search")
# Override max_results from config if set
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results", max_results)
if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = tool_config.model_extra.get("max_results", max_results)
results = _search_images(
query=query,
@@ -1,6 +1,7 @@
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
from deerflow.utils.readability import ReadabilityExtractor
from .infoquest_client import InfoQuestClient
@@ -8,13 +9,13 @@ from .infoquest_client import InfoQuestClient
readability_extractor = ReadabilityExtractor()
def _get_infoquest_client() -> InfoQuestClient:
search_config = get_app_config().get_tool_config("web_search")
def _get_infoquest_client(app_config: AppConfig) -> InfoQuestClient:
search_config = app_config.get_tool_config("web_search")
search_time_range = -1
if search_config is not None and "search_time_range" in search_config.model_extra:
search_time_range = search_config.model_extra.get("search_time_range")
fetch_config = get_app_config().get_tool_config("web_fetch")
fetch_config = app_config.get_tool_config("web_fetch")
fetch_time = -1
if fetch_config is not None and "fetch_time" in fetch_config.model_extra:
fetch_time = fetch_config.model_extra.get("fetch_time")
@@ -25,7 +26,7 @@ def _get_infoquest_client() -> InfoQuestClient:
if fetch_config is not None and "navigation_timeout" in fetch_config.model_extra:
navigation_timeout = fetch_config.model_extra.get("navigation_timeout")
image_search_config = get_app_config().get_tool_config("image_search")
image_search_config = app_config.get_tool_config("image_search")
image_search_time_range = -1
if image_search_config is not None and "image_search_time_range" in image_search_config.model_extra:
image_search_time_range = image_search_config.model_extra.get("image_search_time_range")
@@ -44,19 +45,18 @@ def _get_infoquest_client() -> InfoQuestClient:
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
client = _get_infoquest_client()
client = _get_infoquest_client(resolve_context(runtime).app_config)
return client.web_search(query)
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -66,7 +66,7 @@ def web_fetch_tool(url: str) -> str:
Args:
url: The URL to fetch the contents of.
"""
client = _get_infoquest_client()
client = _get_infoquest_client(resolve_context(runtime).app_config)
result = client.fetch(url)
if result.startswith("Error: "):
return result
@@ -75,7 +75,7 @@ def web_fetch_tool(url: str) -> str:
@tool("image_search", parse_docstring=True)
def image_search_tool(query: str) -> str:
def image_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search for images online. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.
**When to use:**
@@ -89,5 +89,5 @@ def image_search_tool(query: str) -> str:
Args:
query: The query to search for images.
"""
client = _get_infoquest_client()
client = _get_infoquest_client(resolve_context(runtime).app_config)
return client.image_search(query)
@@ -1,16 +1,16 @@
import asyncio
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from deerflow.community.jina_ai.jina_client import JinaClient
from deerflow.config import get_app_config
from deerflow.config.deer_flow_context import resolve_context
from deerflow.utils.readability import ReadabilityExtractor
readability_extractor = ReadabilityExtractor()
@tool("web_fetch", parse_docstring=True)
async def web_fetch_tool(url: str) -> str:
async def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -22,9 +22,9 @@ async def web_fetch_tool(url: str) -> str:
"""
jina_client = JinaClient()
timeout = 10
config = get_app_config().get_tool_config("web_fetch")
if config is not None and "timeout" in config.model_extra:
timeout = config.model_extra.get("timeout")
tool_config = resolve_context(runtime).app_config.get_tool_config("web_fetch")
if tool_config is not None and "timeout" in tool_config.model_extra:
timeout = tool_config.model_extra.get("timeout")
html_content = await jina_client.crawl(url, return_format="html", timeout=timeout)
if isinstance(html_content, str) and html_content.startswith("Error:"):
return html_content
@@ -1,32 +1,34 @@
import json
from langchain.tools import tool
from langchain.tools import ToolRuntime, tool
from tavily import TavilyClient
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
def _get_tavily_client() -> TavilyClient:
config = get_app_config().get_tool_config("web_search")
def _get_tavily_client(app_config: AppConfig) -> TavilyClient:
tool_config = app_config.get_tool_config("web_search")
api_key = None
if config is not None and "api_key" in config.model_extra:
api_key = config.model_extra.get("api_key")
if tool_config is not None and "api_key" in tool_config.model_extra:
api_key = tool_config.model_extra.get("api_key")
return TavilyClient(api_key=api_key)
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str) -> str:
def web_search_tool(query: str, runtime: ToolRuntime) -> str:
"""Search the web.
Args:
query: The query to search for.
"""
config = get_app_config().get_tool_config("web_search")
app_config = resolve_context(runtime).app_config
tool_config = app_config.get_tool_config("web_search")
max_results = 5
if config is not None and "max_results" in config.model_extra:
max_results = config.model_extra.get("max_results")
if tool_config is not None and "max_results" in tool_config.model_extra:
max_results = tool_config.model_extra.get("max_results")
client = _get_tavily_client()
client = _get_tavily_client(app_config)
res = client.search(query, max_results=max_results)
normalized_results = [
{
@@ -41,7 +43,7 @@ def web_search_tool(query: str) -> str:
@tool("web_fetch", parse_docstring=True)
def web_fetch_tool(url: str) -> str:
def web_fetch_tool(url: str, runtime: ToolRuntime) -> str:
"""Fetch the contents of a web page at a given URL.
Only fetch EXACT URLs that have been provided directly by the user or have been returned in results from the web_search and web_fetch tools.
This tool can NOT access content that requires authentication, such as private Google Docs or pages behind login walls.
@@ -51,7 +53,8 @@ def web_fetch_tool(url: str) -> str:
Args:
url: The URL to fetch the contents of.
"""
client = _get_tavily_client()
app_config = resolve_context(runtime).app_config
client = _get_tavily_client(app_config)
res = client.extract([url])
if "failed_results" in res and len(res["failed_results"]) > 0:
return f"Error: {res['failed_results'][0]['error']}"
@@ -1,6 +1,6 @@
from .app_config import get_app_config
from .extensions_config import ExtensionsConfig, get_extensions_config
from .memory_config import MemoryConfig, get_memory_config
from .app_config import AppConfig
from .extensions_config import ExtensionsConfig
from .memory_config import MemoryConfig
from .paths import Paths, get_paths
from .skill_evolution_config import SkillEvolutionConfig
from .skills_config import SkillsConfig
@@ -13,18 +13,16 @@ from .tracing_config import (
)
__all__ = [
"get_app_config",
"SkillEvolutionConfig",
"Paths",
"get_paths",
"SkillsConfig",
"AppConfig",
"ExtensionsConfig",
"get_extensions_config",
"MemoryConfig",
"get_memory_config",
"get_tracing_config",
"get_explicitly_enabled_tracing_providers",
"Paths",
"SkillEvolutionConfig",
"SkillsConfig",
"get_enabled_tracing_providers",
"get_explicitly_enabled_tracing_providers",
"get_paths",
"get_tracing_config",
"is_tracing_enabled",
"validate_enabled_tracing_providers",
]
@@ -1,16 +1,13 @@
"""ACP (Agent Client Protocol) agent configuration loaded from config.yaml."""
import logging
from collections.abc import Mapping
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from pydantic import BaseModel, ConfigDict, Field
class ACPAgentConfig(BaseModel):
"""Configuration for a single ACP-compatible agent."""
model_config = ConfigDict(frozen=True)
command: str = Field(description="Command to launch the ACP agent subprocess")
args: list[str] = Field(default_factory=list, description="Additional command arguments")
env: dict[str, str] = Field(default_factory=dict, description="Environment variables to inject into the agent subprocess. Values starting with $ are resolved from host environment variables.")
@@ -24,28 +21,3 @@ class ACPAgentConfig(BaseModel):
"are denied — the agent must be configured to operate without requesting permissions."
),
)
_acp_agents: dict[str, ACPAgentConfig] = {}
def get_acp_agents() -> dict[str, ACPAgentConfig]:
"""Get the currently configured ACP agents.
Returns:
Mapping of agent name -> ACPAgentConfig. Empty dict if no ACP agents are configured.
"""
return _acp_agents
def load_acp_config_from_dict(config_dict: Mapping[str, Mapping[str, object]] | None) -> None:
"""Load ACP agent configuration from a dictionary (typically from config.yaml).
Args:
config_dict: Mapping of agent name -> config fields.
"""
global _acp_agents
if config_dict is None:
config_dict = {}
_acp_agents = {name: ACPAgentConfig(**cfg) for name, cfg in config_dict.items()}
logger.info("ACP config loaded: %d agent(s): %s", len(_acp_agents), list(_acp_agents.keys()))
@@ -1,32 +1,14 @@
"""Configuration for the custom agents management API."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class AgentsApiConfig(BaseModel):
"""Configuration for custom-agent and user-profile management routes."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description=("Whether to expose the custom-agent management API over HTTP. When disabled, the gateway rejects read/write access to custom agent SOUL.md, config, and USER.md prompt-management routes."),
)
_agents_api_config: AgentsApiConfig = AgentsApiConfig()
def get_agents_api_config() -> AgentsApiConfig:
"""Get the current agents API configuration."""
return _agents_api_config
def set_agents_api_config(config: AgentsApiConfig) -> None:
"""Set the agents API configuration."""
global _agents_api_config
_agents_api_config = config
def load_agents_api_config_from_dict(config_dict: dict) -> None:
"""Load agents API configuration from a dictionary."""
global _agents_api_config
_agents_api_config = AgentsApiConfig(**config_dict)
@@ -5,7 +5,7 @@ import re
from typing import Any
import yaml
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from deerflow.config.paths import get_paths
@@ -29,6 +29,8 @@ def validate_agent_name(name: str | None) -> str | None:
class AgentConfig(BaseModel):
"""Configuration for a custom agent."""
model_config = ConfigDict(frozen=True)
name: str
description: str = ""
model: str | None = None
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
@@ -8,23 +9,25 @@ import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.agents_api_config import AgentsApiConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.memory_config import MemoryConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.stream_bridge_config import StreamBridgeConfig
from deerflow.config.subagents_config import SubagentsAppConfig
from deerflow.config.summarization_config import SummarizationConfig
from deerflow.config.title_config import TitleConfig
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
from deerflow.config.tool_search_config import ToolSearchConfig
load_dotenv()
@@ -65,9 +68,12 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
model_config = ConfigDict(extra="allow", frozen=True)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP agent configurations keyed by agent name")
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
@@ -115,49 +121,6 @@ class AppConfig(BaseModel):
config_data = cls.resolve_env_variables(config_data)
# Load title config if present
if "title" in config_data:
load_title_config_from_dict(config_data["title"])
# Load summarization config if present
if "summarization" in config_data:
load_summarization_config_from_dict(config_data["summarization"])
# Load memory config if present
if "memory" in config_data:
load_memory_config_from_dict(config_data["memory"])
# Always refresh agents API config so removed config sections reset
# singleton-backed state to its default/disabled values on reload.
load_agents_api_config_from_dict(config_data.get("agents_api") or {})
# Load subagents config if present
if "subagents" in config_data:
load_subagents_config_from_dict(config_data["subagents"])
# Load tool_search config if present
if "tool_search" in config_data:
load_tool_search_config_from_dict(config_data["tool_search"])
# Load guardrails config if present
if "guardrails" in config_data:
load_guardrails_config_from_dict(config_data["guardrails"])
# Load circuit_breaker config if present
if "circuit_breaker" in config_data:
config_data["circuit_breaker"] = config_data["circuit_breaker"]
# Load checkpointer config if present
if "checkpointer" in config_data:
load_checkpointer_config_from_dict(config_data["checkpointer"])
# Load stream bridge config if present
if "stream_bridge" in config_data:
load_stream_bridge_config_from_dict(config_data["stream_bridge"])
# Always refresh ACP agent config so removed entries do not linger across reloads.
load_acp_config_from_dict(config_data.get("acp_agents", {}))
# Load extensions config separately (it's in a different file)
extensions_config = ExtensionsConfig.from_file()
config_data["extensions"] = extensions_config.model_dump()
@@ -268,130 +231,8 @@ class AppConfig(BaseModel):
"""
return next((group for group in self.tool_groups if group.name == name), None)
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Reload the config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded AppConfig instance.
"""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Reset the cached config instance.
This clears the singleton cache, causing the next call to
`get_app_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Set a custom config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The AppConfig instance to use.
"""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
# AppConfig is a pure value object: construct with ``from_file()``, pass around.
# Composition roots that hold the resolved instance:
# - Gateway: ``app.state.config`` via ``Depends(get_config)``
# - Client: ``DeerFlowClient._app_config``
# - Agent run: ``Runtime[DeerFlowContext].context.app_config``
@@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
CheckpointerType = Literal["memory", "sqlite", "postgres"]
@@ -10,6 +10,8 @@ CheckpointerType = Literal["memory", "sqlite", "postgres"]
class CheckpointerConfig(BaseModel):
"""Configuration for LangGraph state persistence checkpointer."""
model_config = ConfigDict(frozen=True)
type: CheckpointerType = Field(
description="Checkpointer backend type. "
"'memory' is in-process only (lost on restart). "
@@ -23,24 +25,3 @@ class CheckpointerConfig(BaseModel):
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
)
# Global configuration instance — None means no checkpointer is configured.
_checkpointer_config: CheckpointerConfig | None = None
def get_checkpointer_config() -> CheckpointerConfig | None:
"""Get the current checkpointer configuration, or None if not configured."""
return _checkpointer_config
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
"""Set the checkpointer configuration."""
global _checkpointer_config
_checkpointer_config = config
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
"""Load checkpointer configuration from a dictionary."""
global _checkpointer_config
_checkpointer_config = CheckpointerConfig(**config_dict)
@@ -0,0 +1,103 @@
"""Unified database backend configuration.
Controls BOTH the LangGraph checkpointer and the DeerFlow application
persistence layer (runs, threads metadata, users, etc.). The user
configures one backend; the system handles physical separation details.
SQLite mode: checkpointer and app share a single .db file
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
connection. WAL allows concurrent readers and a single writer without
blocking, making a unified file safe for both workloads. Writers
that contend for the lock wait via the default 5-second sqlite3
busy timeout rather than failing immediately.
Postgres mode: both use the same database URL but maintain independent
connection pools with different lifecycles.
Memory mode: checkpointer uses MemorySaver, app uses in-memory stores.
No database is initialized.
Sensitive values (postgres_url) should use $VAR syntax in config.yaml
to reference environment variables from .env:
database:
backend: postgres
postgres_url: $DATABASE_URL
The $VAR resolution is handled by AppConfig.resolve_env_variables()
before this config is instantiated -- DatabaseConfig itself does not
need to do any environment variable processing.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
class DatabaseConfig(BaseModel):
model_config = ConfigDict(frozen=True)
backend: Literal["memory", "sqlite", "postgres"] = Field(
default="memory",
description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
)
postgres_url: str = Field(
default="",
description=(
"PostgreSQL connection URL, shared by checkpointer and app. "
"Use $DATABASE_URL in config.yaml to reference .env. "
"Example: postgresql://user:pass@host:5432/deerflow "
"(the +asyncpg driver suffix is added automatically where needed)."
),
)
echo_sql: bool = Field(
default=False,
description="Echo all SQL statements to log (debug only).",
)
pool_size: int = Field(
default=5,
description="Connection pool size for the app ORM engine (postgres only).",
)
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path (relative to CWD)."""
from pathlib import Path
return str(Path(self.sqlite_dir).resolve())
@property
def sqlite_path(self) -> str:
"""Unified SQLite file path shared by checkpointer and app."""
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
# Backward-compatible aliases
@property
def checkpointer_sqlite_path(self) -> str:
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
return self.sqlite_path
@property
def app_sqlite_path(self) -> str:
"""SQLite file path for application ORM data (alias for sqlite_path)."""
return self.sqlite_path
@property
def app_sqlalchemy_url(self) -> str:
"""SQLAlchemy async URL for the application ORM engine."""
if self.backend == "sqlite":
return f"sqlite+aiosqlite:///{self.sqlite_path}"
if self.backend == "postgres":
url = self.postgres_url
if url.startswith("postgresql://"):
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
return url
raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}")
@@ -0,0 +1,55 @@
"""Per-invocation context for DeerFlow agent execution.
Injected via LangGraph Runtime. Middleware and tools access this
via Runtime[DeerFlowContext] parameters, through resolve_context().
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DeerFlowContext:
"""Typed, immutable, per-invocation context injected via LangGraph Runtime.
Fields are all known at run start and never change during execution.
Mutable runtime state (e.g. sandbox_id) flows through ThreadState, not here.
"""
app_config: AppConfig
thread_id: str
agent_name: str | None = None
def resolve_context(runtime: Any) -> DeerFlowContext:
"""Return the typed DeerFlowContext that the runtime carries.
Gateway mode (``DeerFlowClient``, ``run_agent``) always attaches a typed
``DeerFlowContext`` via ``agent.astream(context=...)``; the LangGraph
Server path uses ``langgraph.json`` registration where the top-level
``make_lead_agent`` loads ``AppConfig`` from disk itself, so we still
arrive here with a typed context.
Only the dict/None shapes that legacy tests used to exercise would fall
through this function; we now reject them loudly instead of papering
over the missing context with an ambient ``AppConfig`` lookup.
"""
ctx = getattr(runtime, "context", None)
if isinstance(ctx, DeerFlowContext):
return ctx
raise RuntimeError(
"resolve_context: runtime.context is not a DeerFlowContext "
"(got type %s). Every entry point must attach one at invoke time — "
"Gateway/Client via agent.astream(context=DeerFlowContext(...)), "
"LangGraph Server via the make_lead_agent boundary that loads "
"AppConfig.from_file()." % type(ctx).__name__
)
@@ -11,6 +11,8 @@ from pydantic import BaseModel, ConfigDict, Field
class McpOAuthConfig(BaseModel):
"""OAuth configuration for an MCP server (HTTP/SSE transports)."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether OAuth token injection is enabled")
token_url: str = Field(description="OAuth token endpoint URL")
grant_type: Literal["client_credentials", "refresh_token"] = Field(
@@ -28,12 +30,13 @@ class McpOAuthConfig(BaseModel):
default_token_type: str = Field(default="Bearer", description="Default token type when missing in token response")
refresh_skew_seconds: int = Field(default=60, description="Refresh token this many seconds before expiry")
extra_token_params: dict[str, str] = Field(default_factory=dict, description="Additional form params sent to token endpoint")
model_config = ConfigDict(extra="allow")
class McpServerConfig(BaseModel):
"""Configuration for a single MCP server."""
model_config = ConfigDict(extra="allow", frozen=True)
enabled: bool = Field(default=True, description="Whether this MCP server is enabled")
type: str = Field(default="stdio", description="Transport type: 'stdio', 'sse', or 'http'")
command: str | None = Field(default=None, description="Command to execute to start the MCP server (for stdio type)")
@@ -43,12 +46,13 @@ class McpServerConfig(BaseModel):
headers: dict[str, str] = Field(default_factory=dict, description="HTTP headers to send (for sse or http type)")
oauth: McpOAuthConfig | None = Field(default=None, description="OAuth configuration (for sse or http type)")
description: str = Field(default="", description="Human-readable description of what this MCP server provides")
model_config = ConfigDict(extra="allow")
class SkillStateConfig(BaseModel):
"""Configuration for a single skill's state."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=True, description="Whether this skill is enabled")
@@ -64,7 +68,7 @@ class ExtensionsConfig(BaseModel):
default_factory=dict,
description="Map of skill name to state configuration",
)
model_config = ConfigDict(extra="allow", populate_by_name=True)
model_config = ConfigDict(extra="allow", frozen=True, populate_by_name=True)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path | None:
@@ -195,62 +199,3 @@ class ExtensionsConfig(BaseModel):
# Default to enable for public & custom skill
return skill_category in ("public", "custom")
return skill_config.enabled
_extensions_config: ExtensionsConfig | None = None
def get_extensions_config() -> ExtensionsConfig:
"""Get the extensions config instance.
Returns a cached singleton instance. Use `reload_extensions_config()` to reload
from file, or `reset_extensions_config()` to clear the cache.
Returns:
The cached ExtensionsConfig instance.
"""
global _extensions_config
if _extensions_config is None:
_extensions_config = ExtensionsConfig.from_file()
return _extensions_config
def reload_extensions_config(config_path: str | None = None) -> ExtensionsConfig:
"""Reload the extensions config from file and update the cached instance.
This is useful when the config file has been modified and you want
to pick up the changes without restarting the application.
Args:
config_path: Optional path to extensions config file. If not provided,
uses the default resolution strategy.
Returns:
The newly loaded ExtensionsConfig instance.
"""
global _extensions_config
_extensions_config = ExtensionsConfig.from_file(config_path)
return _extensions_config
def reset_extensions_config() -> None:
"""Reset the cached extensions config instance.
This clears the singleton cache, causing the next call to
`get_extensions_config()` to reload from file. Useful for testing
or when switching between different configurations.
"""
global _extensions_config
_extensions_config = None
def set_extensions_config(config: ExtensionsConfig) -> None:
"""Set a custom extensions config instance.
This allows injecting a custom or mock config for testing purposes.
Args:
config: The ExtensionsConfig instance to use.
"""
global _extensions_config
_extensions_config = config
@@ -1,11 +1,13 @@
"""Configuration for pre-tool-call authorization."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class GuardrailProviderConfig(BaseModel):
"""Configuration for a guardrail provider."""
model_config = ConfigDict(frozen=True)
use: str = Field(description="Class path (e.g. 'deerflow.guardrails.builtin:AllowlistProvider')")
config: dict = Field(default_factory=dict, description="Provider-specific settings passed as kwargs")
@@ -18,31 +20,9 @@ class GuardrailsConfig(BaseModel):
agent's passport reference, and returns an allow/deny decision.
"""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable guardrail middleware")
fail_closed: bool = Field(default=True, description="Block tool calls if provider errors")
passport: str | None = Field(default=None, description="OAP passport path or hosted agent ID")
provider: GuardrailProviderConfig | None = Field(default=None, description="Guardrail provider configuration")
_guardrails_config: GuardrailsConfig | None = None
def get_guardrails_config() -> GuardrailsConfig:
"""Get the guardrails config, returning defaults if not loaded."""
global _guardrails_config
if _guardrails_config is None:
_guardrails_config = GuardrailsConfig()
return _guardrails_config
def load_guardrails_config_from_dict(data: dict) -> GuardrailsConfig:
"""Load guardrails config from a dict (called during AppConfig loading)."""
global _guardrails_config
_guardrails_config = GuardrailsConfig.model_validate(data)
return _guardrails_config
def reset_guardrails_config() -> None:
"""Reset the cached config instance. Used in tests to prevent singleton leaks."""
global _guardrails_config
_guardrails_config = None
@@ -1,11 +1,13 @@
"""Configuration for memory mechanism."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class MemoryConfig(BaseModel):
"""Configuration for global memory mechanism."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=True,
description="Whether to enable memory mechanism",
@@ -14,8 +16,9 @@ class MemoryConfig(BaseModel):
default="",
description=(
"Path to store memory data. "
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
"Absolute paths are used as-is. "
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. "
"Absolute paths are used as-is and opt out of per-user isolation "
"(all users share the same file). "
"Relative paths are resolved against `Paths.base_dir` "
"(not the backend working directory). "
"Note: if you previously set this to `.deer-flow/memory.json`, "
@@ -59,24 +62,3 @@ class MemoryConfig(BaseModel):
le=8000,
description="Maximum tokens to use for memory injection",
)
# Global configuration instance
_memory_config: MemoryConfig = MemoryConfig()
def get_memory_config() -> MemoryConfig:
"""Get the current memory configuration."""
return _memory_config
def set_memory_config(config: MemoryConfig) -> None:
"""Set the memory configuration."""
global _memory_config
_memory_config = config
def load_memory_config_from_dict(config_dict: dict) -> None:
"""Load memory configuration from a dictionary."""
global _memory_config
_memory_config = MemoryConfig(**config_dict)
@@ -12,7 +12,7 @@ class ModelConfig(BaseModel):
description="Class path of the model provider(e.g. langchain_openai.ChatOpenAI)",
)
model: str = Field(..., description="Model name")
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
use_responses_api: bool | None = Field(
default=None,
description="Whether to route OpenAI ChatOpenAI calls through the /v1/responses API",
@@ -7,6 +7,7 @@ from pathlib import Path, PureWindowsPath
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path:
@@ -22,6 +23,13 @@ def _validate_thread_id(thread_id: str) -> str:
return thread_id
def _validate_user_id(user_id: str) -> str:
"""Validate a user ID before using it in filesystem paths."""
if not _SAFE_USER_ID_RE.match(user_id):
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
return user_id
def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style.
@@ -134,44 +142,63 @@ class Paths:
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json"
def thread_dir(self, thread_id: str) -> Path:
def user_dir(self, user_id: str) -> Path:
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
return self.base_dir / "users" / _validate_user_id(user_id)
def user_memory_file(self, user_id: str) -> Path:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
Host path for a thread's data.
When *user_id* is provided:
`{base_dir}/users/{user_id}/threads/{thread_id}/`
Otherwise (legacy layout):
`{base_dir}/threads/{thread_id}/`
This directory contains a `user-data/` subdirectory that is mounted
as `/mnt/user-data/` inside the sandbox.
Raises:
ValueError: If `thread_id` contains unsafe characters (path separators
or `..`) that could cause directory traversal.
ValueError: If `thread_id` or `user_id` contains unsafe characters (path
separators or `..`) that could cause directory traversal.
"""
if user_id is not None:
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
return self.base_dir / "threads" / _validate_thread_id(thread_id)
def sandbox_work_dir(self, thread_id: str) -> Path:
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for the agent's workspace directory.
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
Sandbox: `/mnt/user-data/workspace/`
"""
return self.thread_dir(thread_id) / "user-data" / "workspace"
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
def sandbox_uploads_dir(self, thread_id: str) -> Path:
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for user-uploaded files.
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
Sandbox: `/mnt/user-data/uploads/`
"""
return self.thread_dir(thread_id) / "user-data" / "uploads"
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"
def sandbox_outputs_dir(self, thread_id: str) -> Path:
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for agent-generated artifacts.
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
Sandbox: `/mnt/user-data/outputs/`
"""
return self.thread_dir(thread_id) / "user-data" / "outputs"
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"
def acp_workspace_dir(self, thread_id: str) -> Path:
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for the ACP workspace of a specific thread.
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
@@ -180,41 +207,43 @@ class Paths:
Each thread gets its own isolated ACP workspace so that concurrent
sessions cannot read each other's ACP agent outputs.
"""
return self.thread_dir(thread_id) / "acp-workspace"
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"
def sandbox_user_data_dir(self, thread_id: str) -> Path:
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for the user-data root.
Host: `{base_dir}/threads/{thread_id}/user-data/`
Sandbox: `/mnt/user-data/`
"""
return self.thread_dir(thread_id) / "user-data"
return self.thread_dir(thread_id, user_id=user_id) / "user-data"
def host_thread_dir(self, thread_id: str) -> str:
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for a thread directory, preserving Windows path syntax."""
if user_id is not None:
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for a thread's user-data root."""
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")
def host_sandbox_work_dir(self, thread_id: str) -> str:
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the workspace mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the uploads mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the outputs mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs")
def host_acp_workspace_dir(self, thread_id: str) -> str:
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the ACP workspace mount source."""
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace")
def ensure_thread_dirs(self, thread_id: str) -> None:
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None:
"""Create all standard sandbox directories for a thread.
Directories are created with mode 0o777 so that sandbox containers
@@ -228,24 +257,24 @@ class Paths:
ACP agent invocation.
"""
for d in [
self.sandbox_work_dir(thread_id),
self.sandbox_uploads_dir(thread_id),
self.sandbox_outputs_dir(thread_id),
self.acp_workspace_dir(thread_id),
self.sandbox_work_dir(thread_id, user_id=user_id),
self.sandbox_uploads_dir(thread_id, user_id=user_id),
self.sandbox_outputs_dir(thread_id, user_id=user_id),
self.acp_workspace_dir(thread_id, user_id=user_id),
]:
d.mkdir(parents=True, exist_ok=True)
d.chmod(0o777)
def delete_thread_dir(self, thread_id: str) -> None:
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None:
"""Delete all persisted data for a thread.
The operation is idempotent: missing thread directories are ignored.
"""
thread_dir = self.thread_dir(thread_id)
thread_dir = self.thread_dir(thread_id, user_id=user_id)
if thread_dir.exists():
shutil.rmtree(thread_dir)
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
"""Resolve a sandbox virtual path to the actual host filesystem path.
Args:
@@ -253,6 +282,7 @@ class Paths:
virtual_path: Virtual path as seen inside the sandbox, e.g.
``/mnt/user-data/outputs/report.pdf``.
Leading slashes are stripped before matching.
user_id: Optional user ID for user-scoped path resolution.
Returns:
The resolved absolute host filesystem path.
@@ -270,7 +300,7 @@ class Paths:
raise ValueError(f"Path must start with /{prefix}")
relative = stripped[len(prefix) :].lstrip("/")
base = self.sandbox_user_data_dir(thread_id).resolve()
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve()
actual = (base / relative).resolve()
try:
@@ -0,0 +1,34 @@
"""Run event storage configuration.
Controls where run events (messages + execution traces) are persisted.
Backends:
- memory: In-memory storage, data lost on restart. Suitable for
development and testing.
- db: SQL database via SQLAlchemy ORM. Provides full query capability.
Suitable for production deployments.
- jsonl: Append-only JSONL files. Lightweight alternative for
single-node deployments that need persistence without a database.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
class RunEventsConfig(BaseModel):
model_config = ConfigDict(frozen=True)
backend: Literal["memory", "db", "jsonl"] = Field(
default="memory",
description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
)
max_trace_content: int = Field(
default=10240,
description="Maximum trace content size in bytes before truncation (db backend only).",
)
track_token_usage: bool = Field(
default=True,
description="Whether RunJournal should accumulate token counts to RunRow.",
)
@@ -4,6 +4,8 @@ from pydantic import BaseModel, ConfigDict, Field
class VolumeMountConfig(BaseModel):
"""Configuration for a volume mount."""
model_config = ConfigDict(frozen=True)
host_path: str = Field(..., description="Path on the host machine")
container_path: str = Field(..., description="Path inside the container")
read_only: bool = Field(default=False, description="Whether the mount is read-only")
@@ -80,4 +82,4 @@ class SandboxConfig(BaseModel):
description="Maximum characters to keep from ls tool output. Output exceeding this limit is head-truncated. Set to 0 to disable truncation.",
)
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,9 +1,11 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class SkillEvolutionConfig(BaseModel):
"""Configuration for agent-managed skill evolution."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Whether the agent can create and modify skills under skills/custom.",
@@ -1,6 +1,6 @@
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
def _default_repo_root() -> Path:
@@ -11,6 +11,8 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
model_config = ConfigDict(frozen=True)
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
@@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
StreamBridgeType = Literal["memory", "redis"]
@@ -10,6 +10,8 @@ StreamBridgeType = Literal["memory", "redis"]
class StreamBridgeConfig(BaseModel):
"""Configuration for the stream bridge that connects agent workers to SSE endpoints."""
model_config = ConfigDict(frozen=True)
type: StreamBridgeType = Field(
default="memory",
description="Stream bridge backend type. 'memory' uses in-process asyncio.Queue (single-process only). 'redis' uses Redis Streams (planned for Phase 2, not yet implemented).",
@@ -22,25 +24,3 @@ class StreamBridgeConfig(BaseModel):
default=256,
description="Maximum number of events buffered per run in the memory bridge.",
)
# Global configuration instance — None means no stream bridge is configured
# (falls back to memory with defaults).
_stream_bridge_config: StreamBridgeConfig | None = None
def get_stream_bridge_config() -> StreamBridgeConfig | None:
"""Get the current stream bridge configuration, or None if not configured."""
return _stream_bridge_config
def set_stream_bridge_config(config: StreamBridgeConfig | None) -> None:
"""Set the stream bridge configuration."""
global _stream_bridge_config
_stream_bridge_config = config
def load_stream_bridge_config_from_dict(config_dict: dict) -> None:
"""Load stream bridge configuration from a dictionary."""
global _stream_bridge_config
_stream_bridge_config = StreamBridgeConfig(**config_dict)
@@ -1,15 +1,13 @@
"""Configuration for the subagent system loaded from config.yaml."""
import logging
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from pydantic import BaseModel, ConfigDict, Field
class SubagentOverrideConfig(BaseModel):
"""Per-agent configuration overrides."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int | None = Field(
default=None,
ge=1,
@@ -71,6 +69,8 @@ class CustomSubagentConfig(BaseModel):
class SubagentsAppConfig(BaseModel):
"""Configuration for the subagent system."""
model_config = ConfigDict(frozen=True)
timeout_seconds: int = Field(
default=900,
ge=1,
@@ -140,48 +140,3 @@ class SubagentsAppConfig(BaseModel):
if override is not None and override.skills is not None:
return override.skills
return None
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
def get_subagents_app_config() -> SubagentsAppConfig:
"""Get the current subagents configuration."""
return _subagents_config
def load_subagents_config_from_dict(config_dict: dict) -> None:
"""Load subagents configuration from a dictionary."""
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if override.model is not None:
parts.append(f"model={override.model}")
if override.skills is not None:
parts.append(f"skills={override.skills}")
if parts:
overrides_summary[name] = ", ".join(parts)
custom_agents_names = list(_subagents_config.custom_agents.keys())
if overrides_summary or custom_agents_names:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s, custom_agents=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)
@@ -2,7 +2,7 @@
from typing import Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
ContextSizeType = Literal["fraction", "tokens", "messages"]
@@ -10,6 +10,8 @@ ContextSizeType = Literal["fraction", "tokens", "messages"]
class ContextSize(BaseModel):
"""Context size specification for trigger or keep parameters."""
model_config = ConfigDict(frozen=True)
type: ContextSizeType = Field(description="Type of context size specification")
value: int | float = Field(description="Value for the context size specification")
@@ -21,6 +23,8 @@ class ContextSize(BaseModel):
class SummarizationConfig(BaseModel):
"""Configuration for automatic conversation summarization."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Whether to enable automatic conversation summarization",
@@ -70,24 +74,3 @@ class SummarizationConfig(BaseModel):
default_factory=lambda: ["read_file", "read", "view", "cat"],
description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
)
# Global configuration instance
_summarization_config: SummarizationConfig = SummarizationConfig()
def get_summarization_config() -> SummarizationConfig:
"""Get the current summarization configuration."""
return _summarization_config
def set_summarization_config(config: SummarizationConfig) -> None:
"""Set the summarization configuration."""
global _summarization_config
_summarization_config = config
def load_summarization_config_from_dict(config_dict: dict) -> None:
"""Load summarization configuration from a dictionary."""
global _summarization_config
_summarization_config = SummarizationConfig(**config_dict)
@@ -1,11 +1,13 @@
"""Configuration for automatic thread title generation."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class TitleConfig(BaseModel):
"""Configuration for automatic thread title generation."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=True,
description="Whether to enable automatic title generation",
@@ -30,24 +32,3 @@ class TitleConfig(BaseModel):
default=("Generate a concise title (max {max_words} words) for this conversation.\nUser: {user_msg}\nAssistant: {assistant_msg}\n\nReturn ONLY the title, no quotes, no explanation."),
description="Prompt template for title generation",
)
# Global configuration instance
_title_config: TitleConfig = TitleConfig()
def get_title_config() -> TitleConfig:
"""Get the current title configuration."""
return _title_config
def set_title_config(config: TitleConfig) -> None:
"""Set the title configuration."""
global _title_config
_title_config = config
def load_title_config_from_dict(config_dict: dict) -> None:
"""Load title configuration from a dictionary."""
global _title_config
_title_config = TitleConfig(**config_dict)
@@ -1,7 +1,9 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class TokenUsageConfig(BaseModel):
"""Configuration for token usage tracking."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(default=False, description="Enable token usage tracking middleware")
@@ -5,7 +5,7 @@ class ToolGroupConfig(BaseModel):
"""Config section for a tool group"""
name: str = Field(..., description="Unique name for the tool group")
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
class ToolConfig(BaseModel):
@@ -17,4 +17,4 @@ class ToolConfig(BaseModel):
...,
description="Variable name of the tool provider(e.g. deerflow.sandbox.tools:bash_tool)",
)
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra="allow", frozen=True)
@@ -1,6 +1,6 @@
"""Configuration for deferred tool loading via tool_search."""
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
class ToolSearchConfig(BaseModel):
@@ -11,25 +11,9 @@ class ToolSearchConfig(BaseModel):
via the tool_search tool at runtime.
"""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(
default=False,
description="Defer tools and enable tool_search",
)
_tool_search_config: ToolSearchConfig | None = None
def get_tool_search_config() -> ToolSearchConfig:
"""Get the tool search config, loading from AppConfig if needed."""
global _tool_search_config
if _tool_search_config is None:
_tool_search_config = ToolSearchConfig()
return _tool_search_config
def load_tool_search_config_from_dict(data: dict) -> ToolSearchConfig:
"""Load tool search config from a dict (called during AppConfig loading)."""
global _tool_search_config
_tool_search_config = ToolSearchConfig.model_validate(data)
return _tool_search_config
@@ -1,7 +1,7 @@
import os
import threading
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
_config_lock = threading.Lock()
@@ -9,6 +9,8 @@ _config_lock = threading.Lock()
class LangSmithTracingConfig(BaseModel):
"""Configuration for LangSmith tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...)
api_key: str | None = Field(...)
project: str = Field(...)
@@ -26,6 +28,8 @@ class LangSmithTracingConfig(BaseModel):
class LangfuseTracingConfig(BaseModel):
"""Configuration for Langfuse tracing."""
model_config = ConfigDict(frozen=True)
enabled: bool = Field(...)
public_key: str | None = Field(...)
secret_key: str | None = Field(...)
@@ -50,6 +54,8 @@ class LangfuseTracingConfig(BaseModel):
class TracingConfig(BaseModel):
"""Tracing configuration for supported providers."""
model_config = ConfigDict(frozen=True)
langsmith: LangSmithTracingConfig = Field(...)
langfuse: LangfuseTracingConfig = Field(...)
@@ -2,7 +2,7 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
@@ -46,16 +46,23 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
def create_chat_model(
name: str | None = None,
thinking_enabled: bool = False,
*,
app_config: "AppConfig",
**kwargs,
) -> BaseChatModel:
"""Create a chat model instance from the config.
Args:
name: The name of the model to create. If None, the first model in the config will be used.
app_config: Application config — required.
Returns:
A chat model instance.
"""
config = get_app_config()
config = app_config
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)
@@ -0,0 +1,13 @@
"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM).
This module manages DeerFlow's own application data -- runs metadata,
thread ownership, cron jobs, users. It is completely separate from
LangGraph's checkpointer, which manages graph execution state.
Usage:
from deerflow.persistence import init_engine, close_engine, get_session_factory
"""
from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine
__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"]
@@ -0,0 +1,40 @@
"""SQLAlchemy declarative base with automatic to_dict support.
All DeerFlow ORM models inherit from this Base. It provides a generic
to_dict() method via SQLAlchemy's inspect() so individual models don't
need to write their own serialization logic.
LangGraph's checkpointer tables are NOT managed by this Base.
"""
from __future__ import annotations
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
"""Base class for all DeerFlow ORM models.
Provides:
- Automatic to_dict() via SQLAlchemy column inspection.
- Standard __repr__() showing all column values.
"""
def to_dict(self, *, exclude: set[str] | None = None) -> dict:
"""Convert ORM instance to plain dict.
Uses SQLAlchemy's inspect() to iterate mapped column attributes.
Args:
exclude: Optional set of column keys to omit.
Returns:
Dict of {column_key: value} for all mapped columns.
"""
exclude = exclude or set()
return {c.key: getattr(self, c.key) for c in sa_inspect(type(self)).mapper.column_attrs if c.key not in exclude}
def __repr__(self) -> str:
cols = ", ".join(f"{c.key}={getattr(self, c.key)!r}" for c in sa_inspect(type(self)).mapper.column_attrs)
return f"{type(self).__name__}({cols})"
@@ -0,0 +1,190 @@
"""Async SQLAlchemy engine lifecycle management.
Initializes at Gateway startup, provides session factory for
repositories, disposes at shutdown.
When database.backend="memory", init_engine is a no-op and
get_session_factory() returns None. Repositories must check for
None and fall back to in-memory implementations.
"""
from __future__ import annotations
import json
import logging
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
def _json_serializer(obj: object) -> str:
"""JSON serializer with ensure_ascii=False for Chinese character support."""
return json.dumps(obj, ensure_ascii=False)
logger = logging.getLogger(__name__)
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
async def _auto_create_postgres_db(url: str) -> None:
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
The target database name is extracted from *url*. The connection is
made to the default ``postgres`` database on the same server using
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
transaction).
"""
from sqlalchemy import text
from sqlalchemy.engine.url import make_url
parsed = make_url(url)
db_name = parsed.database
if not db_name:
raise ValueError("Cannot auto-create database: no database name in URL")
# Connect to the default 'postgres' database to issue CREATE DATABASE
maint_url = parsed.set(database="postgres")
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
try:
async with maint_engine.connect() as conn:
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
logger.info("Auto-created PostgreSQL database: %s", db_name)
finally:
await maint_engine.dispose()
async def init_engine(
backend: str,
*,
url: str = "",
echo: bool = False,
pool_size: int = 5,
sqlite_dir: str = "",
) -> None:
"""Create the async engine and session factory, then auto-create tables.
Args:
backend: "memory", "sqlite", or "postgres".
url: SQLAlchemy async URL (for sqlite/postgres).
echo: Echo SQL to log.
pool_size: Postgres connection pool size.
sqlite_dir: Directory to create for SQLite (ensured to exist).
"""
global _engine, _session_factory
if backend == "memory":
logger.info("Persistence backend=memory -- ORM engine not initialized")
return
if backend == "postgres":
try:
import asyncpg # noqa: F401
except ImportError:
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
if backend == "sqlite":
import os
from sqlalchemy import event
os.makedirs(sqlite_dir or ".", exist_ok=True)
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
# Enable WAL on every new connection. SQLite PRAGMA settings are
# per-connection, so we wire the listener instead of running PRAGMA
# once at startup. WAL gives concurrent reads + writers without
# blocking and is the standard recommendation for any production
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
# at WAL checkpoint boundaries instead of every commit.
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
# driver already defaults to a 5-second busy timeout (see the
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
# it again would be a no-op.
@event.listens_for(_engine.sync_engine, "connect")
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
cursor = dbapi_conn.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL;")
cursor.execute("PRAGMA synchronous=NORMAL;")
cursor.execute("PRAGMA foreign_keys=ON;")
finally:
cursor.close()
elif backend == "postgres":
_engine = create_async_engine(
url,
echo=echo,
pool_size=pool_size,
pool_pre_ping=True,
json_serializer=_json_serializer,
)
else:
raise ValueError(f"Unknown persistence backend: {backend!r}")
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
# Auto-create tables (dev convenience). Production should use Alembic.
from deerflow.persistence.base import Base
# Import all models so Base.metadata discovers them.
# When no models exist yet (scaffolding phase), this is a no-op.
try:
import deerflow.persistence.models # noqa: F401
except ImportError:
# Models package not yet available — tables won't be auto-created.
# This is expected during initial scaffolding or minimal installs.
logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
try:
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
except Exception as exc:
if backend == "postgres" and "does not exist" in str(exc):
# Database not yet created — attempt to auto-create it, then retry.
await _auto_create_postgres_db(url)
# Rebuild engine against the now-existing database
await _engine.dispose()
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
else:
raise
logger.info("Persistence engine initialized: backend=%s", backend)
async def init_engine_from_config(config) -> None:
"""Convenience: init engine from a DatabaseConfig object."""
if config.backend == "memory":
await init_engine("memory")
return
await init_engine(
backend=config.backend,
url=config.app_sqlalchemy_url,
echo=config.echo_sql,
pool_size=config.pool_size,
sqlite_dir=config.sqlite_dir if config.backend == "sqlite" else "",
)
def get_session_factory() -> async_sessionmaker[AsyncSession] | None:
"""Return the async session factory, or None if backend=memory."""
return _session_factory
def get_engine() -> AsyncEngine | None:
"""Return the async engine, or None if not initialized."""
return _engine
async def close_engine() -> None:
"""Dispose the engine, release all connections."""
global _engine, _session_factory
if _engine is not None:
await _engine.dispose()
logger.info("Persistence engine closed")
_engine = None
_session_factory = None
@@ -0,0 +1,6 @@
"""Feedback persistence — ORM and SQL repository."""
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.feedback.sql import FeedbackRepository
__all__ = ["FeedbackRepository", "FeedbackRow"]
@@ -0,0 +1,34 @@
"""ORM model for user feedback on runs."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import DateTime, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class FeedbackRow(Base):
__tablename__ = "feedback"
__table_args__ = (
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
)
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
message_id: Mapped[str | None] = mapped_column(String(64))
# message_id is an optional RunEventStore event identifier —
# allows feedback to target a specific message or the entire run
rating: Mapped[int] = mapped_column(nullable=False)
# +1 (thumbs-up) or -1 (thumbs-down)
comment: Mapped[str | None] = mapped_column(Text)
# Optional text feedback from the user
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
@@ -0,0 +1,217 @@
"""SQLAlchemy-backed feedback storage.
Each method acquires its own short-lived session.
"""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class FeedbackRepository:
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _row_to_dict(row: FeedbackRow) -> dict:
d = row.to_dict()
val = d.get("created_at")
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
return d
async def create(
self,
*,
run_id: str,
thread_id: str,
rating: int,
user_id: str | None | _AutoSentinel = AUTO,
message_id: str | None = None,
comment: str | None = None,
) -> dict:
"""Create a feedback record. rating must be +1 or -1."""
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
row = FeedbackRow(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
user_id=resolved_user_id,
message_id=message_id,
rating=rating,
comment=comment,
created_at=datetime.now(UTC),
)
async with self._sf() as session:
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def get(
self,
feedback_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return None
if resolved_user_id is not None and row.user_id != resolved_user_id:
return None
return self._row_to_dict(row)
async def list_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 100,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_by_thread(
self,
thread_id: str,
*,
limit: int = 100,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def delete(
self,
feedback_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> bool:
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return False
if resolved_user_id is not None and row.user_id != resolved_user_id:
return False
await session.delete(row)
await session.commit()
return True
async def upsert(
self,
*,
run_id: str,
thread_id: str,
rating: int,
user_id: str | None | _AutoSentinel = AUTO,
comment: str | None = None,
) -> dict:
"""Create or update feedback for (thread_id, run_id, user_id). rating must be +1 or -1."""
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert")
async with self._sf() as session:
stmt = select(FeedbackRow).where(
FeedbackRow.thread_id == thread_id,
FeedbackRow.run_id == run_id,
FeedbackRow.user_id == resolved_user_id,
)
result = await session.execute(stmt)
row = result.scalar_one_or_none()
if row is not None:
row.rating = rating
row.comment = comment
row.created_at = datetime.now(UTC)
else:
row = FeedbackRow(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
user_id=resolved_user_id,
rating=rating,
comment=comment,
created_at=datetime.now(UTC),
)
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def delete_by_run(
self,
*,
thread_id: str,
run_id: str,
user_id: str | None | _AutoSentinel = AUTO,
) -> bool:
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
async with self._sf() as session:
stmt = select(FeedbackRow).where(
FeedbackRow.thread_id == thread_id,
FeedbackRow.run_id == run_id,
FeedbackRow.user_id == resolved_user_id,
)
result = await session.execute(stmt)
row = result.scalar_one_or_none()
if row is None:
return False
await session.delete(row)
await session.commit()
return True
async def list_by_thread_grouped(
self,
thread_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> dict[str, dict]:
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
async with self._sf() as session:
result = await session.execute(stmt)
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
"""Aggregate feedback stats for a run using database-side counting."""
stmt = select(
func.count().label("total"),
func.coalesce(func.sum(case((FeedbackRow.rating == 1, 1), else_=0)), 0).label("positive"),
func.coalesce(func.sum(case((FeedbackRow.rating == -1, 1), else_=0)), 0).label("negative"),
).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
async with self._sf() as session:
row = (await session.execute(stmt)).one()
return {
"run_id": run_id,
"total": row.total,
"positive": row.positive,
"negative": row.negative,
}
@@ -0,0 +1,38 @@
[alembic]
script_location = %(here)s
# Default URL for offline mode / autogenerate.
# Runtime uses engine from DeerFlow config.
sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
@@ -0,0 +1,65 @@
"""Alembic environment for DeerFlow application tables.
ONLY manages DeerFlow's tables (runs, threads_meta, cron_jobs, users).
LangGraph's checkpointer tables are managed by LangGraph itself -- they
have their own schema lifecycle and must not be touched by Alembic.
"""
from __future__ import annotations
import asyncio
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from deerflow.persistence.base import Base
# Import all models so metadata is populated.
try:
import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
except ImportError:
# Models not available — migration will work with existing metadata only.
logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection):
context.configure(
connection=connection,
target_metadata=target_metadata,
render_as_batch=True, # Required for SQLite ALTER TABLE support
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
connectable = create_async_engine(config.get_main_option("sqlalchemy.url"))
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
@@ -0,0 +1,23 @@
"""ORM model registration entry point.
Importing this module ensures all ORM models are registered with
``Base.metadata`` so Alembic autogenerate detects every table.
The actual ORM classes have moved to entity-specific subpackages:
- ``deerflow.persistence.thread_meta``
- ``deerflow.persistence.run``
- ``deerflow.persistence.feedback``
- ``deerflow.persistence.user``
``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because
its storage implementation lives in ``deerflow.runtime.events.store.db`` and
there is no matching entity directory.
"""
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.run.model import RunRow
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.user.model import UserRow
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
@@ -0,0 +1,35 @@
"""ORM model for run events."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class RunEventRow(Base):
__tablename__ = "run_events"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
# Owner of the conversation this event belongs to. Nullable for data
# created before auth was introduced; populated by auth middleware on
# new writes and by the boot-time orphan migration on existing rows.
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle"
content: Mapped[str] = mapped_column(Text, default="")
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
seq: Mapped[int] = mapped_column(nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
__table_args__ = (
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
)
@@ -0,0 +1,6 @@
"""Run metadata persistence — ORM and SQL repository."""
from deerflow.persistence.run.model import RunRow
from deerflow.persistence.run.sql import RunRepository
__all__ = ["RunRepository", "RunRow"]
@@ -0,0 +1,49 @@
"""ORM model for run metadata."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, Index, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class RunRow(Base):
__tablename__ = "runs"
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128))
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
status: Mapped[str] = mapped_column(String(20), default="pending")
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
model_name: Mapped[str | None] = mapped_column(String(128))
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
error: Mapped[str | None] = mapped_column(Text)
# Convenience fields (for listing pages without querying RunEventStore)
message_count: Mapped[int] = mapped_column(default=0)
first_human_message: Mapped[str | None] = mapped_column(Text)
last_ai_message: Mapped[str | None] = mapped_column(Text)
# Token usage (accumulated in-memory by RunJournal, written on run completion)
total_input_tokens: Mapped[int] = mapped_column(default=0)
total_output_tokens: Mapped[int] = mapped_column(default=0)
total_tokens: Mapped[int] = mapped_column(default=0)
llm_call_count: Mapped[int] = mapped_column(default=0)
lead_agent_tokens: Mapped[int] = mapped_column(default=0)
subagent_tokens: Mapped[int] = mapped_column(default=0)
middleware_tokens: Mapped[int] = mapped_column(default=0)
# Follow-up association
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
@@ -0,0 +1,255 @@
"""SQLAlchemy-backed RunStore implementation.
Each method acquires and releases its own short-lived session.
Run status updates happen from background workers that may live
minutes -- we don't hold connections across long execution.
"""
from __future__ import annotations
import json
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
if obj is None:
return None
if isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, dict):
return {k: RunRepository._safe_json(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [RunRepository._safe_json(v) for v in obj]
if hasattr(obj, "model_dump"):
try:
return obj.model_dump()
except Exception:
pass
if hasattr(obj, "dict"):
try:
return obj.dict()
except Exception:
pass
try:
json.dumps(obj)
return obj
except (TypeError, ValueError):
return str(obj)
@staticmethod
def _row_to_dict(row: RunRow) -> dict[str, Any]:
d = row.to_dict()
# Remap JSON columns to match RunStore interface
d["metadata"] = d.pop("metadata_json", {})
d["kwargs"] = d.pop("kwargs_json", {})
# Convert datetime to ISO string for consistency with MemoryRunStore
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
async def put(
self,
run_id,
*,
thread_id,
assistant_id=None,
user_id: str | None | _AutoSentinel = AUTO,
status="pending",
multitask_strategy="reject",
metadata=None,
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
kwargs_json=self._safe_json(kwargs) or {},
error=error,
follow_up_to_run_id=follow_up_to_run_id,
created_at=datetime.fromisoformat(created_at) if created_at else now,
updated_at=now,
)
async with self._sf() as session:
session.add(row)
await session.commit()
async def get(
self,
run_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is None:
return None
if resolved_user_id is not None and row.user_id != resolved_user_id:
return None
return self._row_to_dict(row)
async def list_by_thread(
self,
thread_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
limit=100,
):
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if resolved_user_id is not None:
stmt = stmt.where(RunRow.user_id == resolved_user_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_status(self, run_id, status, *, error=None):
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def delete(
self,
run_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is None:
return
if resolved_user_id is not None and row.user_id != resolved_user_id:
return
await session.delete(row)
await session.commit()
async def list_pending(self, *, before=None):
if before is None:
before_dt = datetime.now(UTC)
elif isinstance(before, datetime):
before_dt = before
else:
before_dt = datetime.fromisoformat(before)
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= before_dt).order_by(RunRow.created_at.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
"""Update status + token usage + convenience fields on run completion."""
values: dict[str, Any] = {
"status": status,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
"updated_at": datetime.now(UTC),
}
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id
stmt = (
select(
func.coalesce(RunRow.model_name, "unknown").label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
func.coalesce(func.sum(RunRow.total_output_tokens), 0).label("total_output_tokens"),
func.coalesce(func.sum(RunRow.lead_agent_tokens), 0).label("lead_agent"),
func.coalesce(func.sum(RunRow.subagent_tokens), 0).label("subagent"),
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
)
.where(_thread, _completed)
.group_by(func.coalesce(RunRow.model_name, "unknown"))
)
async with self._sf() as session:
rows = (await session.execute(stmt)).all()
total_tokens = total_input = total_output = total_runs = 0
lead_agent = subagent = middleware = 0
by_model: dict[str, dict] = {}
for r in rows:
by_model[r.model] = {"tokens": r.total_tokens, "runs": r.runs}
total_tokens += r.total_tokens
total_input += r.total_input_tokens
total_output += r.total_output_tokens
total_runs += r.runs
lead_agent += r.lead_agent
subagent += r.subagent
middleware += r.middleware
return {
"total_tokens": total_tokens,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_runs": total_runs,
"by_model": by_model,
"by_caller": {
"lead_agent": lead_agent,
"subagent": subagent,
"middleware": middleware,
},
}
@@ -0,0 +1,38 @@
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
from __future__ import annotations
from typing import TYPE_CHECKING
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
if TYPE_CHECKING:
from langgraph.store.base import BaseStore
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
__all__ = [
"MemoryThreadMetaStore",
"ThreadMetaRepository",
"ThreadMetaRow",
"ThreadMetaStore",
"make_thread_store",
]
def make_thread_store(
session_factory: async_sessionmaker[AsyncSession] | None,
store: BaseStore | None = None,
) -> ThreadMetaStore:
"""Create the appropriate ThreadMetaStore based on available backends.
Returns a SQL-backed repository when a session factory is available,
otherwise falls back to the in-memory LangGraph Store implementation.
"""
if session_factory is not None:
return ThreadMetaRepository(session_factory)
if store is None:
raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
return MemoryThreadMetaStore(store)
@@ -0,0 +1,76 @@
"""Abstract interface for thread metadata storage.
Implementations:
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
All mutating and querying methods accept a ``user_id`` parameter with
three-state semantics (see :mod:`deerflow.runtime.user_context`):
- ``AUTO`` (default): resolve from the request-scoped contextvar.
- Explicit ``str``: use the provided value verbatim.
- Explicit ``None``: bypass owner filtering (migration/CLI only).
"""
from __future__ import annotations
import abc
from deerflow.runtime.user_context import AUTO, _AutoSentinel
class ThreadMetaStore(abc.ABC):
@abc.abstractmethod
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
pass
@abc.abstractmethod
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
pass
@abc.abstractmethod
async def search(
self,
*,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
pass
@abc.abstractmethod
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
pass
@abc.abstractmethod
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
pass
@abc.abstractmethod
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
"""Merge ``metadata`` into the thread's metadata field.
Existing keys are overwritten by the new values; keys absent from
``metadata`` are preserved. No-op if the thread does not exist
or the owner check fails.
"""
pass
@abc.abstractmethod
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``user_id`` has access to ``thread_id``."""
pass
@abc.abstractmethod
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
pass
@@ -0,0 +1,149 @@
"""In-memory ThreadMetaStore backed by LangGraph BaseStore.
Used when database.backend=memory. Delegates to the LangGraph Store's
``("threads",)`` namespace — the same namespace used by the Gateway
router for thread records.
"""
from __future__ import annotations
import time
from typing import Any
from langgraph.store.base import BaseStore
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
THREADS_NS: tuple[str, ...] = ("threads",)
class MemoryThreadMetaStore(ThreadMetaStore):
def __init__(self, store: BaseStore) -> None:
self._store = store
async def _get_owned_record(
self,
thread_id: str,
user_id: str | None | _AutoSentinel,
method_name: str,
) -> dict | None:
"""Fetch a record and verify ownership. Returns a mutable copy, or None."""
resolved = resolve_user_id(user_id, method_name=method_name)
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
return None
record = dict(item.value)
if resolved is not None and record.get("user_id") != resolved:
return None
return record
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
now = time.time()
record: dict[str, Any] = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": resolved_user_id,
"display_name": display_name,
"status": "idle",
"metadata": metadata or {},
"values": {},
"created_at": now,
"updated_at": now,
}
await self._store.aput(THREADS_NS, thread_id, record)
return record
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
async def search(
self,
*,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
filter_dict: dict[str, Any] = {}
if metadata:
filter_dict.update(metadata)
if status:
filter_dict["status"] = status
if resolved_user_id is not None:
filter_dict["user_id"] = resolved_user_id
items = await self._store.asearch(
THREADS_NS,
filter=filter_dict or None,
limit=limit,
offset=offset,
)
return [self._item_to_dict(item) for item in items]
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
item = await self._store.aget(THREADS_NS, thread_id)
if item is None:
return not require_existing
record_user_id = item.value.get("user_id")
if record_user_id is None:
return True
return record_user_id == user_id
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
if record is None:
return
record["display_name"] = display_name
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
if record is None:
return
record["status"] = status
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
if record is None:
return
merged = dict(record.get("metadata") or {})
merged.update(metadata)
record["metadata"] = merged
record["updated_at"] = time.time()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
if record is None:
return
await self._store.adelete(THREADS_NS, thread_id)
@staticmethod
def _item_to_dict(item) -> dict[str, Any]:
"""Convert a Store SearchItem to the dict format expected by callers."""
val = item.value
return {
"thread_id": item.key,
"assistant_id": val.get("assistant_id"),
"user_id": val.get("user_id"),
"display_name": val.get("display_name"),
"status": val.get("status", "idle"),
"metadata": val.get("metadata", {}),
"created_at": str(val.get("created_at", "")),
"updated_at": str(val.get("updated_at", "")),
}
@@ -0,0 +1,23 @@
"""ORM model for thread metadata."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, DateTime, String
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class ThreadMetaRow(Base):
__tablename__ = "threads_meta"
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
display_name: Mapped[str | None] = mapped_column(String(256))
status: Mapped[str] = mapped_column(String(20), default="idle")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
@@ -0,0 +1,217 @@
"""SQLAlchemy-backed thread metadata repository."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class ThreadMetaRepository(ThreadMetaStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
d = row.to_dict()
d["metadata"] = d.pop("metadata_json", {})
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
# Auto-resolve user_id from contextvar when AUTO; explicit None
# creates an orphan row (used by migration scripts).
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
now = datetime.now(UTC)
row = ThreadMetaRow(
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
display_name=display_name,
metadata_json=metadata or {},
created_at=now,
updated_at=now,
)
async with self._sf() as session:
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def get(
self,
thread_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return None
# Enforce owner filter unless explicitly bypassed (user_id=None).
if resolved_user_id is not None and row.user_id != resolved_user_id:
return None
return self._row_to_dict(row)
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``user_id`` has access to ``thread_id``.
Two modes — one row, two distinct semantics depending on what
the caller is about to do:
- ``require_existing=False`` (default, permissive):
Returns True for: row missing (untracked legacy thread),
``row.user_id`` is None (shared / pre-auth data),
or ``row.user_id == user_id``. Use for **read-style**
decorators where treating an untracked thread as accessible
preserves backward-compat.
- ``require_existing=True`` (strict):
Returns True **only** when the row exists AND
(``row.user_id == user_id`` OR ``row.user_id is None``).
Use for **destructive / mutating** decorators (DELETE, PATCH,
state-update) so a thread that has *already been deleted*
cannot be re-targeted by any caller — closing the
delete-idempotence cross-user gap where the row vanishing
made every other user appear to "own" it.
"""
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return not require_existing
if row.user_id is None:
return True
return row.user_id == user_id
async def search(
self,
*,
metadata: dict | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
"""Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user
context. Pass ``user_id=None`` to bypass (migration/CLI).
"""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_user_id is not None:
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
if status:
stmt = stmt.where(ThreadMetaRow.status == status)
if metadata:
# When metadata filter is active, fetch a larger window and filter
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
# SQLite json_extract) for server-side filtering.
stmt = stmt.limit(limit * 5 + offset)
async with self._sf() as session:
result = await session.execute(stmt)
rows = [self._row_to_dict(r) for r in result.scalars()]
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
return rows[offset : offset + limit]
else:
stmt = stmt.limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
"""Return True if the row exists and is owned (or filter bypassed)."""
if resolved_user_id is None:
return True # explicit bypass
row = await session.get(ThreadMetaRow, thread_id)
return row is not None and row.user_id == resolved_user_id
async def update_display_name(
self,
thread_id: str,
display_name: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Update the display_name (title) for a thread."""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
await session.commit()
async def update_status(
self,
thread_id: str,
status: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit()
async def update_metadata(
self,
thread_id: str,
metadata: dict,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Merge ``metadata`` into ``metadata_json``.
Read-modify-write inside a single session/transaction so concurrent
callers see consistent state. No-op if the row does not exist or
the user_id check fails.
"""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_user_id is not None and row.user_id != resolved_user_id:
return
merged = dict(row.metadata_json or {})
merged.update(metadata)
row.metadata_json = merged
row.updated_at = datetime.now(UTC)
await session.commit()
async def delete(
self,
thread_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_user_id is not None and row.user_id != resolved_user_id:
return
await session.delete(row)
await session.commit()
@@ -0,0 +1,12 @@
"""User storage subpackage.
Holds the ORM model for the ``users`` table. The concrete repository
implementation (``SQLiteUserRepository``) lives in the app layer
(``app.gateway.auth.repositories.sqlite``) because it converts
between the ORM row and the auth module's pydantic ``User`` class.
This keeps the harness package free of any dependency on app code.
"""
from deerflow.persistence.user.model import UserRow
__all__ = ["UserRow"]
@@ -0,0 +1,59 @@
"""ORM model for the users table.
Lives in the harness persistence package so it is picked up by
``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``,
``run_events``, and ``feedback``. Using the shared engine means:
- One SQLite/Postgres database, one connection pool
- One schema initialisation codepath
- Consistent async sessions across auth and persistence reads
"""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import Boolean, DateTime, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class UserRow(Base):
__tablename__ = "users"
# UUIDs are stored as 36-char strings for cross-backend portability.
id: Mapped[str] = mapped_column(String(36), primary_key=True)
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
# "admin" | "user" — kept as plain string to avoid ALTER TABLE pain
# when new roles are introduced.
system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
nullable=False,
default=lambda: datetime.now(UTC),
)
# OAuth linkage (optional). A partial unique index enforces one
# account per (provider, oauth_id) pair, leaving NULL/NULL rows
# unconstrained so plain password accounts can coexist.
oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True)
oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
# Auth lifecycle flags
needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
token_version: Mapped[int] = mapped_column(nullable=False, default=0)
__table_args__ = (
Index(
"idx_users_oauth_identity",
"oauth_provider",
"oauth_id",
unique=True,
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
),
)
@@ -5,15 +5,22 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
directly from ``deerflow.runtime``.
"""
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
__all__ = [
# checkpointer
"checkpointer_context",
"get_checkpointer",
"make_checkpointer",
"reset_checkpointer",
# runs
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",
@@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres.
Usage (e.g. FastAPI lifespan)::
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer # InMemorySaver if not configured
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
"""
from __future__ import annotations
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.agents.checkpointer.provider import (
from deerflow.config.app_config import AppConfig
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
SQLITE_INSTALL,
)
from deerflow.config.app_config import get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -84,23 +84,74 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
@contextlib.asynccontextmanager
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit no global state::
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
if db_config.backend == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
async with _async_checkpointer(config.checkpointer) as saver:
yield saver
if db_config.backend == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if db_config.backend == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
@contextlib.asynccontextmanager
async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Priority:
1. Legacy ``checkpointer:`` config section (backward compatible)
2. Unified ``database:`` config section
3. Default InMemorySaver
"""
# Legacy: standalone checkpointer config takes precedence
if app_config.checkpointer is not None:
async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
return
# Default: in-memory
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
@@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres.
Usage::
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
# Singleton — reused across calls, closed on process exit
cp = get_checkpointer()
@@ -25,7 +25,7 @@ from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
@@ -100,10 +100,13 @@ _checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer() -> Checkpointer:
def get_checkpointer(app_config: AppConfig) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
absent from config.yaml. Any other failure (missing config, invalid
backend, connection error) propagates silent degradation to in-memory
would drop persistent-run state on process restart.
Raises:
ImportError: If the required package for the configured backend is not installed.
@@ -114,25 +117,7 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None:
return _checkpointer
# Ensure app config is loaded before checking checkpointer config
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
# but hasn't been loaded yet
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config()
config = app_config.checkpointer
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
@@ -168,25 +153,23 @@ def reset_checkpointer() -> None:
@contextlib.contextmanager
def checkpointer_context() -> Iterator[Checkpointer]:
def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance
each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context() as cp:
with checkpointer_context(app_config) as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
with _sync_checkpointer_cm(config.checkpointer) as saver:
with _sync_checkpointer_cm(app_config.checkpointer) as saver:
yield saver
@@ -0,0 +1,134 @@
"""Pure functions to convert LangChain message objects to OpenAI Chat Completions format.
Used by RunJournal to build content dicts for event storage.
"""
from __future__ import annotations
import json
from typing import Any
_ROLE_MAP = {
"human": "user",
"ai": "assistant",
"system": "system",
"tool": "tool",
}
def langchain_to_openai_message(message: Any) -> dict:
"""Convert a single LangChain BaseMessage to an OpenAI message dict.
Handles:
- HumanMessage → {"role": "user", "content": "..."}
- AIMessage (text only) → {"role": "assistant", "content": "..."}
- AIMessage (with tool_calls) → {"role": "assistant", "content": null, "tool_calls": [...]}
- AIMessage (text + tool_calls) → both content and tool_calls present
- AIMessage (list content / multimodal) → content preserved as list
- SystemMessage → {"role": "system", "content": "..."}
- ToolMessage → {"role": "tool", "tool_call_id": "...", "content": "..."}
"""
msg_type = getattr(message, "type", "")
role = _ROLE_MAP.get(msg_type, msg_type)
content = getattr(message, "content", "")
if role == "tool":
return {
"role": "tool",
"tool_call_id": getattr(message, "tool_call_id", ""),
"content": content,
}
if role == "assistant":
tool_calls = getattr(message, "tool_calls", None) or []
result: dict = {"role": "assistant"}
if tool_calls:
openai_tool_calls = []
for tc in tool_calls:
args = tc.get("args", {})
openai_tool_calls.append(
{
"id": tc.get("id", ""),
"type": "function",
"function": {
"name": tc.get("name", ""),
"arguments": json.dumps(args) if not isinstance(args, str) else args,
},
}
)
# If no text content, set content to null per OpenAI spec
result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
result["tool_calls"] = openai_tool_calls
else:
result["content"] = content
return result
# user / system / unknown
return {"role": role, "content": content}
def _infer_finish_reason(message: Any) -> str:
"""Infer OpenAI finish_reason from an AIMessage.
Returns "tool_calls" if tool_calls present, else looks in
response_metadata.finish_reason, else returns "stop".
"""
tool_calls = getattr(message, "tool_calls", None) or []
if tool_calls:
return "tool_calls"
resp_meta = getattr(message, "response_metadata", None) or {}
if isinstance(resp_meta, dict):
finish = resp_meta.get("finish_reason")
if finish:
return finish
return "stop"
def langchain_to_openai_completion(message: Any) -> dict:
"""Convert an AIMessage and its metadata to an OpenAI completion response dict.
Returns:
{
"id": message.id,
"model": message.response_metadata.get("model_name"),
"choices": [{"index": 0, "message": <openai_message>, "finish_reason": <inferred>}],
"usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...} or None,
}
"""
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
openai_msg = langchain_to_openai_message(message)
finish_reason = _infer_finish_reason(message)
usage_metadata = getattr(message, "usage_metadata", None)
if usage_metadata is not None:
input_tokens = usage_metadata.get("input_tokens", 0) or 0
output_tokens = usage_metadata.get("output_tokens", 0) or 0
usage: dict | None = {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
else:
usage = None
return {
"id": getattr(message, "id", None),
"model": model_name,
"choices": [
{
"index": 0,
"message": openai_msg,
"finish_reason": finish_reason,
}
],
"usage": usage,
}
def langchain_messages_to_openai(messages: list) -> list[dict]:
"""Convert a list of LangChain BaseMessages to OpenAI message dicts."""
return [langchain_to_openai_message(m) for m in messages]
@@ -0,0 +1,4 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
__all__ = ["MemoryRunEventStore", "RunEventStore"]
@@ -0,0 +1,26 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
def make_run_event_store(config=None) -> RunEventStore:
"""Create a RunEventStore based on run_events.backend configuration."""
if config is None or config.backend == "memory":
return MemoryRunEventStore()
if config.backend == "db":
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
# database.backend=memory but run_events.backend=db -> fallback
return MemoryRunEventStore()
from deerflow.runtime.events.store.db import DbRunEventStore
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
if config.backend == "jsonl":
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
return JsonlRunEventStore()
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
@@ -0,0 +1,109 @@
"""Abstract interface for run event storage.
RunEventStore is the unified storage interface for run event streams.
Messages (frontend display) and execution traces (debugging/audit) go
through the same interface, distinguished by the ``category`` field.
Implementations:
- MemoryRunEventStore: in-memory dict (development, tests)
- Future: DB-backed store (SQLAlchemy ORM), JSONL file store
"""
from __future__ import annotations
import abc
class RunEventStore(abc.ABC):
"""Run event stream storage interface.
All implementations must guarantee:
1. put() events are retrievable in subsequent queries
2. seq is strictly increasing within the same thread
3. list_messages() only returns category="message" events
4. list_events() returns all events for the specified run
5. Returned dicts match the RunEvent field structure
"""
@abc.abstractmethod
async def put(
self,
*,
thread_id: str,
run_id: str,
event_type: str,
category: str,
content: str | dict = "",
metadata: dict | None = None,
created_at: str | None = None,
) -> dict:
"""Write an event, auto-assign seq, return the complete record."""
@abc.abstractmethod
async def put_batch(self, events: list[dict]) -> list[dict]:
"""Batch-write events. Used by RunJournal flush buffer.
Each dict's keys match put()'s keyword arguments.
Returns complete records with seq assigned.
"""
@abc.abstractmethod
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict]:
"""Return displayable messages (category=message) for a thread, ordered by seq ascending.
Supports bidirectional cursor pagination:
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
- neither: return the latest ``limit`` records (ascending)
"""
@abc.abstractmethod
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
) -> list[dict]:
"""Return the full event stream for a run, ordered by seq ascending.
Optionally filter by event_types.
"""
@abc.abstractmethod
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict]:
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
Supports bidirectional cursor pagination:
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
- neither: return the latest ``limit`` records (ascending)
"""
@abc.abstractmethod
async def count_messages(self, thread_id: str) -> int:
"""Count displayable messages (category=message) in a thread."""
@abc.abstractmethod
async def delete_by_thread(self, thread_id: str) -> int:
"""Delete all events for a thread. Return the number of deleted events."""
@abc.abstractmethod
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
"""Delete all events for a specific run. Return the number of deleted events."""
@@ -0,0 +1,286 @@
"""SQLAlchemy-backed RunEventStore implementation.
Persists events to the ``run_events`` table. Trace content is truncated
at ``max_trace_content`` bytes to avoid bloating the database.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
logger = logging.getLogger(__name__)
class DbRunEventStore(RunEventStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
self._sf = session_factory
self._max_trace_content = max_trace_content
@staticmethod
def _row_to_dict(row: RunEventRow) -> dict:
d = row.to_dict()
d["metadata"] = d.pop("event_metadata", {})
val = d.get("created_at")
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
d.pop("id", None)
# Restore dict content that was JSON-serialized on write
raw = d.get("content", "")
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
try:
d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError):
# Content looked like JSON (content_is_dict flag) but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
if category == "trace":
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
encoded = text.encode("utf-8")
if len(encoded) > self._max_trace_content:
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _user_id_from_context() -> str | None:
"""Soft read of user_id from contextvar for write paths.
Returns ``None`` (no filter / no stamp) if contextvar is unset,
which is the expected case for background worker writes. HTTP
request writes will have the contextvar set by auth middleware
and get their user_id stamped automatically.
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
to a VARCHAR column ("type 'UUID' is not supported") — the
INSERT would silently roll back and the worker would hang.
"""
user = get_current_user()
return str(user.id) if user is not None else None
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
This opens a dedicated transaction with a FOR UPDATE lock to
assign a monotonic *seq*. For high-throughput writes use
:meth:`put_batch`, which acquires the lock once for the whole
batch. Currently the only caller is ``worker.run_agent`` for
the initial ``human_message`` event (once per run).
"""
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
user_id=user_id,
event_type=event_type,
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
)
session.add(row)
return self._row_to_dict(row)
async def put_batch(self, events):
if not events:
return []
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
seq = max_seq or 0
rows = []
for e in events:
seq += 1
content = e.get("content", "")
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
user_id=e.get("user_id", user_id),
event_type=e["event_type"],
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
)
session.add(row)
rows.append(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(
self,
thread_id,
*,
limit=50,
before_seq=None,
after_seq=None,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
stmt = stmt.where(RunEventRow.seq > after_seq)
if after_seq is not None:
# Forward pagination: first `limit` records after cursor
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
else:
# before_seq or default (latest): take last `limit` records, return ascending
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def list_events(
self,
thread_id,
run_id,
*,
event_types=None,
limit=500,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_messages_by_run(
self,
thread_id,
run_id,
*,
limit=50,
before_seq=None,
after_seq=None,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(
RunEventRow.thread_id == thread_id,
RunEventRow.run_id == run_id,
RunEventRow.category == "message",
)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
stmt = stmt.where(RunEventRow.seq > after_seq)
if after_seq is not None:
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
else:
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def count_messages(
self,
thread_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def delete_by_thread(
self,
thread_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id]
if resolved_user_id is not None:
count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
async def delete_by_run(
self,
thread_id,
run_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
if resolved_user_id is not None:
count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
@@ -0,0 +1,187 @@
"""JSONL file-backed RunEventStore implementation.
Each run's events are stored in a single file:
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
All categories (message, trace, lifecycle) are in the same file.
This backend is suitable for lightweight single-node deployments.
Known trade-off: ``list_messages()`` must scan all run files for a
thread since messages from multiple runs need unified seq ordering.
``list_events()`` reads only one file -- the fast path.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import UTC, datetime
from pathlib import Path
from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$")
class JsonlRunEventStore(RunEventStore):
def __init__(self, base_dir: str | Path | None = None):
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
@staticmethod
def _validate_id(value: str, label: str) -> str:
"""Validate that an ID is safe for use in filesystem paths."""
if not value or not _SAFE_ID_PATTERN.match(value):
raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
return value
def _thread_dir(self, thread_id: str) -> Path:
self._validate_id(thread_id, "thread_id")
return self._base_dir / "threads" / thread_id / "runs"
def _run_file(self, thread_id: str, run_id: str) -> Path:
self._validate_id(run_id, "run_id")
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
def _next_seq(self, thread_id: str) -> int:
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
return self._seq_counters[thread_id]
def _ensure_seq_loaded(self, thread_id: str) -> None:
"""Load max seq from existing files if not yet cached."""
if thread_id in self._seq_counters:
return
max_seq = 0
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
for line in f.read_text(encoding="utf-8").strip().splitlines():
try:
record = json.loads(line)
max_seq = max(max_seq, record.get("seq", 0))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue
self._seq_counters[thread_id] = max_seq
def _write_record(self, record: dict) -> None:
path = self._run_file(record["thread_id"], record["run_id"])
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "a", encoding="utf-8") as f:
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
def _read_thread_events(self, thread_id: str) -> list[dict]:
"""Read all events for a thread, sorted by seq."""
events = []
thread_dir = self._thread_dir(thread_id)
if not thread_dir.exists():
return events
for f in sorted(thread_dir.glob("*.jsonl")):
for line in f.read_text(encoding="utf-8").strip().splitlines():
if not line:
continue
try:
events.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
"""Read events for a specific run file."""
path = self._run_file(thread_id, run_id)
if not path.exists():
return []
events = []
for line in path.read_text(encoding="utf-8").strip().splitlines():
if not line:
continue
try:
events.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", path)
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
self._ensure_seq_loaded(thread_id)
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._write_record(record)
return record
async def put_batch(self, events):
if not events:
return []
results = []
for ev in events:
record = await self.put(**ev)
results.append(record)
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._read_thread_events(thread_id)
messages = [e for e in all_events if e.get("category") == "message"]
if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq]
return messages[-limit:]
elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq]
return messages[:limit]
else:
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
events = self._read_run_events(thread_id, run_id)
if event_types is not None:
events = [e for e in events if e.get("event_type") in event_types]
return events[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
events = self._read_run_events(thread_id, run_id)
filtered = [e for e in events if e.get("category") == "message"]
if before_seq is not None:
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
if after_seq is not None:
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
if after_seq is not None:
return filtered[:limit]
else:
return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id):
all_events = self._read_thread_events(thread_id)
return sum(1 for e in all_events if e.get("category") == "message")
async def delete_by_thread(self, thread_id):
all_events = self._read_thread_events(thread_id)
count = len(all_events)
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
f.unlink()
self._seq_counters.pop(thread_id, None)
return count
async def delete_by_run(self, thread_id, run_id):
events = self._read_run_events(thread_id, run_id)
count = len(events)
path = self._run_file(thread_id, run_id)
if path.exists():
path.unlink()
return count
@@ -0,0 +1,128 @@
"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests.
Thread-safe for single-process async usage (no threading locks needed
since all mutations happen within the same event loop).
"""
from __future__ import annotations
from datetime import UTC, datetime
from deerflow.runtime.events.store.base import RunEventStore
class MemoryRunEventStore(RunEventStore):
def __init__(self) -> None:
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
def _next_seq(self, thread_id: str) -> int:
current = self._seq_counters.get(thread_id, 0)
next_val = current + 1
self._seq_counters[thread_id] = next_val
return next_val
def _put_one(
self,
*,
thread_id: str,
run_id: str,
event_type: str,
category: str,
content: str | dict = "",
metadata: dict | None = None,
created_at: str | None = None,
) -> dict:
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._events.setdefault(thread_id, []).append(record)
return record
async def put(
self,
*,
thread_id,
run_id,
event_type,
category,
content="",
metadata=None,
created_at=None,
):
return self._put_one(
thread_id=thread_id,
run_id=run_id,
event_type=event_type,
category=category,
content=content,
metadata=metadata,
created_at=created_at,
)
async def put_batch(self, events):
results = []
for ev in events:
record = self._put_one(**ev)
results.append(record)
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._events.get(thread_id, [])
messages = [e for e in all_events if e["category"] == "message"]
if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq]
# Take the last `limit` records
return messages[-limit:]
elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq]
return messages[:limit]
else:
# Return the latest `limit` records, ascending
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
all_events = self._events.get(thread_id, [])
filtered = [e for e in all_events if e["run_id"] == run_id]
if event_types is not None:
filtered = [e for e in filtered if e["event_type"] in event_types]
return filtered[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._events.get(thread_id, [])
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
if before_seq is not None:
filtered = [e for e in filtered if e["seq"] < before_seq]
if after_seq is not None:
filtered = [e for e in filtered if e["seq"] > after_seq]
if after_seq is not None:
return filtered[:limit]
else:
return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id):
all_events = self._events.get(thread_id, [])
return sum(1 for e in all_events if e["category"] == "message")
async def delete_by_thread(self, thread_id):
events = self._events.pop(thread_id, [])
self._seq_counters.pop(thread_id, None)
return len(events)
async def delete_by_run(self, thread_id, run_id):
all_events = self._events.get(thread_id, [])
if not all_events:
return 0
remaining = [e for e in all_events if e["run_id"] != run_id]
removed = len(all_events) - len(remaining)
self._events[thread_id] = remaining
return removed
@@ -0,0 +1,497 @@
"""Run event capture via LangChain callbacks.
RunJournal sits between LangChain's callback mechanism and the pluggable
RunEventStore. It standardizes callback data into RunEvent records and
handles token usage accumulation.
Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
- on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
"""
from __future__ import annotations
import asyncio
import logging
import time
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
if TYPE_CHECKING:
from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
class RunJournal(BaseCallbackHandler):
"""LangChain callback handler that captures events to RunEventStore."""
def __init__(
self,
run_id: str,
thread_id: str,
event_store: RunEventStore,
*,
track_token_usage: bool = True,
flush_threshold: int = 20,
):
super().__init__()
self.run_id = run_id
self.thread_id = thread_id
self._store = event_store
self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold
# Write buffer
self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
# Token accumulators
self._total_input_tokens = 0
self._total_output_tokens = 0
self._total_tokens = 0
self._llm_call_count = 0
self._lead_agent_tokens = 0
self._subagent_tokens = 0
self._middleware_tokens = 0
# Convenience fields
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
self._msg_count = 0
# Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
# LLM request/response tracking
self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
# Tool call ID cache
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
# -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_start",
category="lifecycle",
metadata={"input_preview": str(inputs)[:500]},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync()
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_error",
category="lifecycle",
content=str(error),
metadata={"error_type": type(error).__name__},
)
self._flush_sync()
# -- LLM callbacks --
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
"""Capture structured prompt messages for llm_request event."""
from deerflow.runtime.converters import langchain_messages_to_openai
rid = str(run_id)
self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1
model_name = serialized.get("name", "")
self._cached_models[rid] = model_name
# Convert the first message list (LangChain passes list-of-lists)
prompt_msgs = messages[0] if messages else []
openai_msgs = langchain_messages_to_openai(prompt_msgs)
self._cached_prompts[rid] = openai_msgs
caller = self._identify_caller(kwargs)
self._put(
event_type="llm_request",
category="trace",
content={"model": model_name, "messages": openai_msgs},
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
)
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
# Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.converters import langchain_to_openai_completion
try:
message = response.generations[0][0].message
except (IndexError, AttributeError):
logger.debug("on_llm_end: could not extract message from response")
return
caller = self._identify_caller(kwargs)
# Latency
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Clean up caches
self._cached_prompts.pop(rid, None)
self._cached_models.pop(rid, None)
# Trace event: llm_response (OpenAI completion format)
content = getattr(message, "content", "")
self._put(
event_type="llm_response",
category="trace",
content=langchain_to_openai_completion(message),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
if tool_calls:
# ai_tool_call: agent decided to use tools
self._put(
event_type="ai_tool_call",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
# ai_message: final text reply
self._put(
event_type="ai_message",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm_error", category="trace", content=str(error))
# -- Tool callbacks --
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
tool_call_id = kwargs.get("tool_call_id")
if tool_call_id:
self._tool_call_ids[str(run_id)] = tool_call_id
self._put(
event_type="tool_start",
category="trace",
metadata={
"tool_name": serialized.get("name", ""),
"tool_call_id": tool_call_id,
"args": str(input_str)[:2000],
},
)
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
from langgraph.types import Command
# Tools that update graph state return a ``Command`` (e.g.
# ``present_files``). LangGraph later unwraps the inner ToolMessage
# into checkpoint state, so to stay checkpoint-aligned we must
# extract it here rather than storing ``str(Command(...))``.
if isinstance(output, Command):
update = getattr(output, "update", None) or {}
inner_msgs = update.get("messages") if isinstance(update, dict) else None
if isinstance(inner_msgs, list):
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
if inner_tool_msg is not None:
output = inner_tool_msg
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": status,
},
)
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
if name == "summarization":
data_dict = data if isinstance(data, dict) else {}
self._put(
event_type="summarization",
category="trace",
content=data_dict.get("summary", ""),
metadata={
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
"replaced_count": data_dict.get("replaced_count", 0),
},
)
self._put(
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
else:
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
self._put(
event_type=name,
category="trace",
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
)
# -- Internal methods --
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
self._buffer.append(
{
"thread_id": self.thread_id,
"run_id": self.run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"created_at": datetime.now(UTC).isoformat(),
}
)
if len(self._buffer) >= self._flush_threshold:
self._flush_sync()
def _flush_sync(self) -> None:
"""Best-effort flush of buffer to RunEventStore.
BaseCallbackHandler methods are synchronous. If an event loop is
running we schedule an async ``put_batch``; otherwise the events
stay in the buffer and are flushed later by the async ``flush()``
call in the worker's ``finally`` block.
"""
if not self._buffer:
return
# Skip if a flush is already in flight — avoids concurrent writes
# to the same SQLite file from multiple fire-and-forget tasks.
if self._pending_flush_tasks:
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No event loop — keep events in buffer for later async flush.
return
batch = self._buffer.copy()
self._buffer.clear()
task = loop.create_task(self._flush_async(batch))
self._pending_flush_tasks.add(task)
task.add_done_callback(self._on_flush_done)
async def _flush_async(self, batch: list[dict]) -> None:
try:
await self._store.put_batch(batch)
except Exception:
logger.warning(
"Failed to flush %d events for run %s — returning to buffer",
len(batch),
self.run_id,
exc_info=True,
)
# Return failed events to buffer for retry on next flush
self._buffer = batch + self._buffer
def _on_flush_done(self, task: asyncio.Task) -> None:
self._pending_flush_tasks.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc:
logger.warning("Journal flush task failed: %s", exc)
def _identify_caller(self, kwargs: dict) -> str:
for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
# Default to lead_agent: the main agent graph does not inject
# callback tags, while subagents and middleware explicitly tag
# themselves.
return "lead_agent"
# -- Public methods (called by worker) --
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware state-change event.
Called by middleware implementations when they perform a meaningful
state change (e.g., title generation, summarization, HITL approval).
Pure-observation middleware should not call this.
Args:
tag: Short identifier for the middleware (e.g., "title", "summarize",
"guardrail"). Used to form event_type="middleware:{tag}".
name: Full middleware class name.
hook: Lifecycle hook that triggered the action (e.g., "after_model").
action: Specific action performed (e.g., "generate_title").
changes: Dict describing the state changes made.
"""
self._put(
event_type=f"middleware:{tag}",
category="middleware",
content={"name": name, "hook": hook, "action": action, "changes": changes},
)
async def flush(self) -> None:
"""Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._buffer:
batch = self._buffer[: self._flush_threshold]
del self._buffer[: self._flush_threshold]
try:
await self._store.put_batch(batch)
except Exception:
self._buffer = batch + self._buffer
raise
def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion."""
return {
"total_input_tokens": self._total_input_tokens,
"total_output_tokens": self._total_output_tokens,
"total_tokens": self._total_tokens,
"llm_call_count": self._llm_call_count,
"lead_agent_tokens": self._lead_agent_tokens,
"subagent_tokens": self._subagent_tokens,
"middleware_tokens": self._middleware_tokens,
"message_count": self._msg_count,
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
}
@@ -2,11 +2,12 @@
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import run_agent
from .worker import RunContext, run_agent
__all__ = [
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",
@@ -1,4 +1,4 @@
"""In-memory run registry."""
"""In-memory run registry with optional persistent RunStore backing."""
from __future__ import annotations
@@ -7,9 +7,13 @@ import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
@@ -38,11 +42,44 @@ class RunRecord:
class RunManager:
"""In-memory run registry. All mutations are protected by an asyncio lock."""
"""In-memory run registry with optional persistent RunStore backing.
def __init__(self) -> None:
All mutations are protected by an asyncio lock. When a ``store`` is
provided, serializable metadata is also persisted to the store so
that run history survives process restarts.
"""
def __init__(self, store: RunStore | None = None) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
self._store = store
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
"""Best-effort persist run record to backing store."""
if self._store is None:
return
try:
await self._store.put(
record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
multitask_strategy=record.multitask_strategy,
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def create(
self,
@@ -53,6 +90,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
@@ -71,6 +109,7 @@ class RunManager:
)
async with self._lock:
self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -96,6 +135,11 @@ class RunManager:
record.updated_at = _now_iso()
if error is not None:
record.error = error
if self._store is not None:
try:
await self._store.update_status(run_id, status.value, error=error)
except Exception:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
@@ -132,6 +176,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -185,6 +230,7 @@ class RunManager:
)
self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -0,0 +1,4 @@
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.runs.store.memory import MemoryRunStore
__all__ = ["MemoryRunStore", "RunStore"]
@@ -0,0 +1,96 @@
"""Abstract interface for run metadata storage.
RunManager depends on this interface. Implementations:
- MemoryRunStore: in-memory dict (development, tests)
- Future: RunRepository backed by SQLAlchemy ORM
All methods accept an optional user_id for user isolation.
When user_id is None, no user filtering is applied (single-user mode).
"""
from __future__ import annotations
import abc
from typing import Any
class RunStore(abc.ABC):
@abc.abstractmethod
async def put(
self,
run_id: str,
*,
thread_id: str,
assistant_id: str | None = None,
user_id: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
error: str | None = None,
created_at: str | None = None,
follow_up_to_run_id: str | None = None,
) -> None:
pass
@abc.abstractmethod
async def get(self, run_id: str) -> dict[str, Any] | None:
pass
@abc.abstractmethod
async def list_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 100,
) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def update_status(
self,
run_id: str,
status: str,
*,
error: str | None = None,
) -> None:
pass
@abc.abstractmethod
async def delete(self, run_id: str) -> None:
pass
@abc.abstractmethod
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
pass
@abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens,
total_output_tokens, total_runs, by_model (model_name → {tokens, runs}),
by_caller ({lead_agent, subagent, middleware}).
"""
pass
@@ -0,0 +1,100 @@
"""In-memory RunStore. Used when database.backend=memory (default) and in tests.
Equivalent to the original RunManager._runs dict behavior.
"""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from deerflow.runtime.runs.store.base import RunStore
class MemoryRunStore(RunStore):
def __init__(self) -> None:
self._runs: dict[str, dict[str, Any]] = {}
async def put(
self,
run_id,
*,
thread_id,
assistant_id=None,
user_id=None,
status="pending",
multitask_strategy="reject",
metadata=None,
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
now = datetime.now(UTC).isoformat()
self._runs[run_id] = {
"run_id": run_id,
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": user_id,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
"kwargs": kwargs or {},
"error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"created_at": created_at or now,
"updated_at": now,
}
async def get(self, run_id):
return self._runs.get(run_id)
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
results.sort(key=lambda r: r["created_at"], reverse=True)
return results[:limit]
async def update_status(self, run_id, status, *, error=None):
if run_id in self._runs:
self._runs[run_id]["status"] = status
if error is not None:
self._runs[run_id]["error"] = error
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def delete(self, run_id):
self._runs.pop(run_id, None)
async def update_run_completion(self, run_id, *, status, **kwargs):
if run_id in self._runs:
self._runs[run_id]["status"] = status
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
by_model: dict[str, dict] = {}
for r in completed:
model = r.get("model_name") or "unknown"
entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
entry["tokens"] += r.get("total_tokens", 0)
entry["runs"] += 1
return {
"total_tokens": sum(r.get("total_tokens", 0) for r in completed),
"total_input_tokens": sum(r.get("total_input_tokens", 0) for r in completed),
"total_output_tokens": sum(r.get("total_output_tokens", 0) for r in completed),
"total_runs": len(completed),
"by_model": by_model,
"by_caller": {
"lead_agent": sum(r.get("lead_agent_tokens", 0) for r in completed),
"subagent": sum(r.get("subagent_tokens", 0) for r in completed),
"middleware": sum(r.get("middleware_tokens", 0) for r in completed),
},
}
@@ -19,8 +19,14 @@ import asyncio
import copy
import inspect
import logging
from typing import Any, Literal
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -33,13 +39,30 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
@dataclass(frozen=True)
class RunContext:
"""Infrastructure dependencies for a single agent run.
Groups checkpointer, store, and persistence-related singletons so that
``run_agent`` (and any future callers) receive one object instead of a
growing list of keyword arguments.
"""
checkpointer: Any
store: Any | None = field(default=None)
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
app_config: AppConfig | None = field(default=None)
async def run_agent(
bridge: StreamBridge,
run_manager: RunManager,
record: RunRecord,
*,
checkpointer: Any,
store: Any | None = None,
ctx: RunContext,
agent_factory: Any,
graph_input: dict,
config: dict,
@@ -50,6 +73,14 @@ async def run_agent(
) -> None:
"""Execute an agent in the background, publishing events to *bridge*."""
# Unpack infrastructure dependencies from RunContext.
checkpointer = ctx.checkpointer
store = ctx.store
event_store = ctx.event_store
run_events_config = ctx.run_events_config
thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id
thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"])
@@ -57,6 +88,10 @@ async def run_agent(
pre_run_snapshot: dict[str, Any] | None = None
snapshot_capture_failed = False
journal = None
journal = None
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
@@ -65,6 +100,38 @@ async def run_agent(
)
try:
# Initialize RunJournal + write human_message event.
# These are inside the try block so any exception (e.g. a DB
# error writing the event) flows through the except/finally
# path that publishes an "end" event to the SSE bridge —
# otherwise a failure here would leave the stream hanging
# with no terminator.
if event_store is not None:
from deerflow.runtime.journal import RunJournal
journal = RunJournal(
run_id=run_id,
thread_id=thread_id,
event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# 1. Mark running
await run_manager.set_status(run_id, RunStatus.running)
@@ -98,17 +165,21 @@ async def run_agent(
# 3. Build the agent
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Construct typed context for the agent run.
# LangGraph's astream(context=...) injects this into Runtime.context
# so middleware/tools can access it via resolve_context().
if ctx.app_config is None:
raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
deer_flow_context = DeerFlowContext(
app_config=ctx.app_config,
thread_id=thread_id,
)
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None:
config.setdefault("callbacks", []).append(journal)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
@@ -155,7 +226,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
@@ -166,6 +237,7 @@ async def run_agent(
async for item in agent.astream(
graph_input,
config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes,
subgraphs=stream_subgraphs,
):
@@ -236,6 +308,41 @@ async def run_agent(
)
finally:
# Flush any buffered journal events and persist completion data
if journal is not None:
try:
await journal.flush()
except Exception:
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
try:
# Persist token usage + convenience fields to RunStore
completion = journal.get_completion_data()
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
except Exception:
logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)
# Sync title from checkpoint to threads_meta.display_name
if checkpointer is not None and thread_store is not None:
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is not None:
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
title = ckpt.get("channel_values", {}).get("title")
if title:
await thread_store.update_display_name(thread_id, title)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
# Update threads_meta status based on run outcome
if thread_store is not None:
try:
final_status = "idle" if record.status == RunStatus.success else record.status.value
await thread_store.update_status(thread_id, final_status)
except Exception:
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60))
@@ -355,6 +462,31 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
"""Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get
the checkpoint-aligned serialization format.
"""
from langchain_core.messages import HumanMessage
messages = graph_input.get("messages")
if not messages:
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
content = last.content
return HumanMessage(content=content)
if isinstance(last, dict):
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
def _unpack_stream_item(
item: Any,
lg_modes: list[str],
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -86,28 +86,26 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
@contextlib.asynccontextmanager
async def make_store() -> AsyncIterator[BaseStore]:
async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
"""Async context manager that yields a Store whose backend matches the
configured checkpointer.
Reads from the same ``checkpointer`` section of *config.yaml* used by
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
that both singletons always use the same persistence technology::
async with make_store() as store:
async with make_store(app_config) as store:
app.state.store = store
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
async with _async_store(config.checkpointer) as store:
async with _async_store(app_config.checkpointer) as store:
yield store
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
def get_store() -> BaseStore:
def get_store(app_config: AppConfig) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -115,19 +115,10 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so
# that tests that set the global checkpointer config explicitly remain isolated.
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
# See matching comment in checkpointer/provider.py: a missing config.yaml
# is a deployment error, not a cue to silently pick InMemoryStore. Only
# the explicit "no checkpointer section" path falls through to memory.
config = app_config.checkpointer
if config is None:
from langgraph.store.memory import InMemoryStore
@@ -163,26 +154,25 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance — each
``with`` block creates and destroys its own connection. Use it in CLI
scripts or tests where you want deterministic cleanup::
with store_context() as store:
with store_context(app_config) as store:
store.put(("threads",), thread_id, {...})
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(app_config.checkpointer) as store:
yield store
@@ -1,7 +1,7 @@
"""Async stream bridge factory.
Provides an **async context manager** aligned with
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`.
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`.
Usage (e.g. FastAPI lifespan)::
@@ -17,7 +17,7 @@ import contextlib
import logging
from collections.abc import AsyncIterator
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from deerflow.config.app_config import AppConfig
from .base import StreamBridge
@@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
"""Async context manager that yields a :class:`StreamBridge`.
Falls back to :class:`MemoryStreamBridge` when no configuration is
provided and nothing is set globally.
Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
section is configured.
"""
if config is None:
config = get_stream_bridge_config()
config = app_config.stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -0,0 +1,167 @@
"""Request-scoped user context for user-based authorization.
This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``user_id`` boilerplate.
Three-state semantics for the repository ``user_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``):
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
raise :class:`RuntimeError` if unset.
- Explicit ``str``: use the provided value, overriding contextvar.
- Explicit ``None``: no WHERE clause — used only by migration scripts
and admin CLIs that intentionally bypass isolation.
Dependency direction
--------------------
``persistence`` (lower layer) reads from this module; ``gateway.auth``
(higher layer) writes to it. ``CurrentUser`` is defined here as a
:class:`typing.Protocol` so that ``persistence`` never needs to import
the concrete ``User`` class from ``gateway.auth.models``. Any object
with an ``.id: str`` attribute structurally satisfies the protocol.
Asyncio semantics
-----------------
``ContextVar`` is task-local under asyncio, not thread-local. Each
FastAPI request runs in its own task, so the context is naturally
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
parent task's context, which is typically the intended behaviour; if
a background task must *not* see the foreground user, wrap it with
``contextvars.copy_context()`` to get a clean copy.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Final, Protocol, runtime_checkable
@runtime_checkable
class CurrentUser(Protocol):
"""Structural type for the current authenticated user.
Any object with an ``.id: str`` attribute satisfies this protocol.
Concrete implementations live in ``app.gateway.auth.models.User``.
"""
id: str
_current_user: Final[ContextVar[CurrentUser | None]] = ContextVar("deerflow_current_user", default=None)
def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
"""Set the current user for this async task.
Returns a reset token that should be passed to
:func:`reset_current_user` in a ``finally`` block to restore the
previous context.
"""
return _current_user.set(user)
def reset_current_user(token: Token[CurrentUser | None]) -> None:
"""Restore the context to the state captured by ``token``."""
_current_user.reset(token)
def get_current_user() -> CurrentUser | None:
"""Return the current user, or ``None`` if unset.
Safe to call in any context. Used by code paths that can proceed
without a user (e.g. migration scripts, public endpoints).
"""
return _current_user.get()
def require_current_user() -> CurrentUser:
"""Return the current user, or raise :class:`RuntimeError`.
Used by repository code that must not be called outside a
request-authenticated context. The error message is phrased so
that a caller debugging a stack trace can locate the offending
code path.
"""
user = _current_user.get()
if user is None:
raise RuntimeError("repository accessed without user context")
return user
# ---------------------------------------------------------------------------
# Effective user_id helpers (filesystem isolation)
# ---------------------------------------------------------------------------
DEFAULT_USER_ID: Final[str] = "default"
def get_effective_user_id() -> str:
"""Return the current user's id as a string, or DEFAULT_USER_ID if unset.
Unlike :func:`require_current_user` this never raises — it is designed
for filesystem-path resolution where a valid user bucket is always needed.
"""
user = _current_user.get()
if user is None:
return DEFAULT_USER_ID
return str(user.id)
# ---------------------------------------------------------------------------
# Sentinel-based user_id resolution
# ---------------------------------------------------------------------------
#
# Repository methods accept a ``user_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_user_id`.
class _AutoSentinel:
"""Singleton marker meaning 'resolve user_id from contextvar'."""
_instance: _AutoSentinel | None = None
def __new__(cls) -> _AutoSentinel:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "<AUTO>"
AUTO: Final[_AutoSentinel] = _AutoSentinel()
def resolve_user_id(
value: str | None | _AutoSentinel,
*,
method_name: str = "repository method",
) -> str | None:
"""Resolve the user_id parameter passed to a repository method.
Three-state semantics:
- :data:`AUTO` (default): read from contextvar; raise
:class:`RuntimeError` if no user is in context. This is the
common case for request-scoped calls.
- Explicit ``str``: use the provided id verbatim, overriding any
contextvar value. Useful for tests and admin-override flows.
- Explicit ``None``: no filter — the repository should skip the
user_id WHERE clause entirely. Reserved for migration scripts
and CLI tools that intentionally bypass isolation.
"""
if isinstance(value, _AutoSentinel):
user = _current_user.get()
if user is None:
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
# ``UUID`` for the API surface, but the persistence layer
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
# not supported"). Honour the documented return type here
# rather than ripple a type change through every caller.
return str(user.id)
return value
@@ -1,10 +1,14 @@
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_singleton: LocalSandbox | None = None
@@ -13,8 +17,9 @@ _singleton: LocalSandbox | None = None
class LocalSandboxProvider(SandboxProvider):
uses_thread_data_mounts = True
def __init__(self):
def __init__(self, app_config: "AppConfig"):
"""Initialize the local sandbox provider with path mappings."""
self._app_config = app_config
self._path_mappings = self._setup_path_mappings()
def _setup_path_mappings(self) -> list[PathMapping]:
@@ -31,9 +36,7 @@ class LocalSandboxProvider(SandboxProvider):
# Map skills container path to local skills directory
try:
from deerflow.config import get_app_config
config = get_app_config()
config = self._app_config
skills_path = config.skills.get_skills_path()
container_path = config.skills.container_path
@@ -6,6 +6,7 @@ from langchain.agents.middleware import AgentMiddleware
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import SandboxState, ThreadDataState
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.sandbox import get_sandbox_provider
logger = logging.getLogger(__name__)
@@ -42,41 +43,35 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
super().__init__()
self._lazy_init = lazy_init
def _acquire_sandbox(self, thread_id: str) -> str:
provider = get_sandbox_provider()
def _acquire_sandbox(self, thread_id: str, runtime: Runtime[DeerFlowContext]) -> str:
provider = get_sandbox_provider(runtime.context.app_config)
sandbox_id = provider.acquire(thread_id)
logger.info(f"Acquiring sandbox {sandbox_id}")
return sandbox_id
@override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
# Skip acquisition if lazy_init is enabled
if self._lazy_init:
return super().before_agent(state, runtime)
# Eager initialization (original behavior)
if "sandbox" not in state or state["sandbox"] is None:
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
thread_id = runtime.context.thread_id
if not thread_id:
return super().before_agent(state, runtime)
sandbox_id = self._acquire_sandbox(thread_id)
sandbox_id = self._acquire_sandbox(thread_id, runtime)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
return {"sandbox": {"sandbox_id": sandbox_id}}
return super().before_agent(state, runtime)
@override
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
sandbox = state.get("sandbox")
if sandbox is not None:
sandbox_id = sandbox["sandbox_id"]
logger.info(f"Releasing sandbox {sandbox_id}")
get_sandbox_provider().release(sandbox_id)
return None
if (runtime.context or {}).get("sandbox_id") is not None:
sandbox_id = runtime.context.get("sandbox_id")
logger.info(f"Releasing sandbox {sandbox_id} from context")
get_sandbox_provider().release(sandbox_id)
get_sandbox_provider(runtime.context.app_config).release(sandbox_id)
return None
# No sandbox to release
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.sandbox.sandbox import Sandbox
@@ -41,23 +41,38 @@ class SandboxProvider(ABC):
_default_sandbox_provider: SandboxProvider | None = None
def get_sandbox_provider(**kwargs) -> SandboxProvider:
def get_sandbox_provider(app_config: AppConfig, **kwargs) -> SandboxProvider:
"""Get the sandbox provider singleton.
Returns a cached singleton instance. Use `reset_sandbox_provider()` to clear
the cache, or `shutdown_sandbox_provider()` to properly shutdown and clear.
Args:
app_config: Application config used the first time the singleton is built.
Ignored on subsequent calls — the cached instance is returned
regardless of the config passed.
Returns:
A sandbox provider instance.
"""
global _default_sandbox_provider
if _default_sandbox_provider is None:
config = get_app_config()
cls = resolve_class(config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(**kwargs)
cls = resolve_class(app_config.sandbox.use, SandboxProvider)
_default_sandbox_provider = cls(app_config=app_config, **kwargs) if _accepts_app_config(cls) else cls(**kwargs)
return _default_sandbox_provider
def _accepts_app_config(cls: type) -> bool:
"""Return True when the provider's __init__ accepts an ``app_config`` kwarg."""
import inspect
try:
sig = inspect.signature(cls.__init__)
except (TypeError, ValueError):
return False
return "app_config" in sig.parameters
def reset_sandbox_provider() -> None:
"""Reset the sandbox provider singleton.
@@ -1,6 +1,6 @@
"""Security helpers for sandbox capability gating."""
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
_LOCAL_SANDBOX_PROVIDER_MARKERS = (
"deerflow.sandbox.local:LocalSandboxProvider",
@@ -20,11 +20,8 @@ LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE = (
)
def uses_local_sandbox_provider(config=None) -> bool:
def uses_local_sandbox_provider(config: AppConfig) -> bool:
"""Return True when the active sandbox provider is the host-local provider."""
if config is None:
config = get_app_config()
sandbox_cfg = getattr(config, "sandbox", None)
sandbox_use = getattr(sandbox_cfg, "use", "")
if sandbox_use in _LOCAL_SANDBOX_PROVIDER_MARKERS:
@@ -32,11 +29,8 @@ def uses_local_sandbox_provider(config=None) -> bool:
return sandbox_use.endswith(":LocalSandboxProvider") and "deerflow.sandbox.local" in sandbox_use
def is_host_bash_allowed(config=None) -> bool:
def is_host_bash_allowed(config: AppConfig) -> bool:
"""Return whether host bash execution is explicitly allowed."""
if config is None:
config = get_app_config()
sandbox_cfg = getattr(config, "sandbox", None)
if sandbox_cfg is None:
return False
+161 -198
View File
@@ -7,7 +7,8 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import resolve_context
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
SandboxError,
@@ -39,62 +40,43 @@ _DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
def _get_skills_container_path() -> str:
"""Get the skills container path from config, with fallback to default.
Result is cached after the first successful config load. If config loading
fails the default is returned *without* caching so that a later call can
pick up the real value once the config is available.
"""
cached = getattr(_get_skills_container_path, "_cached", None)
if cached is not None:
return cached
try:
from deerflow.config import get_app_config
value = get_app_config().skills.container_path
_get_skills_container_path._cached = value # type: ignore[attr-defined]
return value
except Exception:
def _get_skills_container_path(app_config: AppConfig) -> str:
"""Get the skills container path from config, with fallback to default."""
skills_cfg = getattr(app_config, "skills", None)
if skills_cfg is None:
return _DEFAULT_SKILLS_CONTAINER_PATH
return skills_cfg.container_path
def _get_skills_host_path() -> str | None:
def _get_skills_host_path(app_config: AppConfig) -> str | None:
"""Get the skills host filesystem path from config.
Returns None if the skills directory does not exist or config cannot be
loaded. Only successful lookups are cached; failures are retried on the
next call so that a transiently unavailable skills directory does not
permanently disable skills access.
Returns None if the skills directory does not exist or is not configured.
"""
cached = getattr(_get_skills_host_path, "_cached", None)
if cached is not None:
return cached
skills_cfg = getattr(app_config, "skills", None)
if skills_cfg is None:
return None
try:
from deerflow.config import get_app_config
config = get_app_config()
skills_path = config.skills.get_skills_path()
if skills_path.exists():
value = str(skills_path)
_get_skills_host_path._cached = value # type: ignore[attr-defined]
return value
skills_path = skills_cfg.get_skills_path()
except Exception:
pass
return None
if skills_path.exists():
return str(skills_path)
return None
def _is_skills_path(path: str) -> bool:
def _is_skills_path(path: str, app_config: AppConfig) -> bool:
"""Check if a path is under the skills container path."""
skills_prefix = _get_skills_container_path()
skills_prefix = _get_skills_container_path(app_config)
return path == skills_prefix or path.startswith(f"{skills_prefix}/")
def _resolve_skills_path(path: str) -> str:
def _resolve_skills_path(path: str, app_config: AppConfig) -> str:
"""Resolve a virtual skills path to a host filesystem path.
Args:
path: Virtual skills path (e.g. /mnt/skills/public/bootstrap/SKILL.md)
app_config: Resolved application config.
Returns:
Resolved host path.
@@ -102,8 +84,8 @@ def _resolve_skills_path(path: str) -> str:
Raises:
FileNotFoundError: If skills directory is not configured or doesn't exist.
"""
skills_container = _get_skills_container_path()
skills_host = _get_skills_host_path()
skills_container = _get_skills_container_path(app_config)
skills_host = _get_skills_host_path(app_config)
if skills_host is None:
raise FileNotFoundError(f"Skills directory not available for path: {path}")
@@ -119,48 +101,31 @@ def _is_acp_workspace_path(path: str) -> bool:
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
def _get_custom_mounts():
def _get_custom_mounts(app_config: AppConfig):
"""Get custom volume mounts from sandbox config.
Result is cached after the first successful config load. If config loading
fails an empty list is returned *without* caching so that a later call can
pick up the real value once the config is available.
Only includes mounts whose host_path exists, consistent with
``LocalSandboxProvider._setup_path_mappings()`` which also filters by
``host_path.exists()``.
"""
cached = getattr(_get_custom_mounts, "_cached", None)
if cached is not None:
return cached
try:
from pathlib import Path
from deerflow.config import get_app_config
config = get_app_config()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
# LocalSandboxProvider._setup_path_mappings() which also filters
# by host_path.exists().
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
return mounts
except Exception:
# If config loading fails, return an empty list without caching so that
# a later call can retry once the config is available.
sandbox_cfg = getattr(app_config, "sandbox", None)
if sandbox_cfg is None or not sandbox_cfg.mounts:
return []
return [m for m in sandbox_cfg.mounts if Path(m.host_path).exists()]
def _is_custom_mount_path(path: str) -> bool:
def _is_custom_mount_path(path: str, app_config: AppConfig) -> bool:
"""Check if path is under a custom mount container_path."""
for mount in _get_custom_mounts():
for mount in _get_custom_mounts(app_config):
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
return True
return False
def _get_custom_mount_for_path(path: str):
def _get_custom_mount_for_path(path: str, app_config: AppConfig):
"""Get the mount config matching this path (longest prefix first)."""
best = None
for mount in _get_custom_mounts():
for mount in _get_custom_mounts(app_config):
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
if best is None or len(mount.container_path) > len(best.container_path):
best = mount
@@ -200,8 +165,9 @@ def _get_acp_workspace_host_path(thread_id: str | None = None) -> str | None:
if thread_id is not None:
try:
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
host_path = get_paths().acp_workspace_dir(thread_id)
host_path = get_paths().acp_workspace_dir(thread_id, user_id=get_effective_user_id())
if host_path.exists():
return str(host_path)
except Exception:
@@ -270,44 +236,40 @@ def _resolve_acp_workspace_path(path: str, thread_id: str | None = None) -> str:
return str(resolved_path)
def _get_mcp_allowed_paths() -> list[str]:
def _get_mcp_allowed_paths(app_config: AppConfig) -> list[str]:
"""Get the list of allowed paths from MCP config for file system server."""
allowed_paths = []
try:
from deerflow.config.extensions_config import get_extensions_config
allowed_paths: list[str] = []
extensions_config = getattr(app_config, "extensions", None)
if extensions_config is None:
return allowed_paths
extensions_config = get_extensions_config()
for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
continue
for _, server in extensions_config.mcp_servers.items():
if not server.enabled:
continue
# Only check the filesystem server
args = server.args or []
# Check if args has server-filesystem package
has_filesystem = any("server-filesystem" in arg for arg in args)
if not has_filesystem:
continue
# Unpack the allowed file system paths in config
for arg in args:
if not arg.startswith("-") and arg.startswith("/"):
allowed_paths.append(arg.rstrip("/") + "/")
except Exception:
pass
# Only check the filesystem server
args = server.args or []
# Check if args has server-filesystem package
has_filesystem = any("server-filesystem" in arg for arg in args)
if not has_filesystem:
continue
# Unpack the allowed file system paths in config
for arg in args:
if not arg.startswith("-") and arg.startswith("/"):
allowed_paths.append(arg.rstrip("/") + "/")
return allowed_paths
def _get_tool_config_int(name: str, key: str, default: int) -> int:
def _get_tool_config_int(app_config: AppConfig, name: str, key: str, default: int) -> int:
try:
tool_config = get_app_config().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
tool_config = app_config.get_tool_config(name)
except Exception:
pass
return default
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
return default
@@ -317,23 +279,23 @@ def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
return min(value, upper_bound)
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
def _resolve_max_results(app_config: AppConfig, name: str, requested: int, *, default: int, upper_bound: int) -> int:
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
configured_max_results = _clamp_max_results(
_get_tool_config_int(name, "max_results", default),
_get_tool_config_int(app_config, name, "max_results", default),
default=default,
upper_bound=upper_bound,
)
return min(requested_max_results, configured_max_results)
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
return _resolve_skills_path(path)
def _resolve_local_read_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path, app_config):
return _resolve_skills_path(path, app_config)
if _is_acp_workspace_path(path):
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
return _resolve_and_validate_user_data_path(path, thread_data)
return _resolve_and_validate_user_data_path(path, thread_data, app_config)
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
@@ -379,7 +341,11 @@ def _join_path_preserving_style(base: str, relative: str) -> str:
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
def _sanitize_error(
error: Exception,
runtime: "ToolRuntime[ContextT, ThreadState] | None" = None,
app_config: AppConfig | None = None,
) -> str:
"""Sanitize an error message to avoid leaking host filesystem paths.
In local-sandbox mode, resolved host paths in the error string are masked
@@ -388,8 +354,12 @@ def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadStat
"""
msg = f"{type(error).__name__}: {error}"
if runtime is not None and is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
msg = mask_local_paths_in_output(msg, thread_data)
if app_config is None:
ctx = getattr(runtime, "context", None)
app_config = getattr(ctx, "app_config", None)
if app_config is not None:
thread_data = get_thread_data(runtime)
msg = mask_local_paths_in_output(msg, thread_data, app_config)
return msg
@@ -459,7 +429,7 @@ def _thread_actual_to_virtual_mappings(thread_data: ThreadDataState) -> dict[str
return {actual: virtual for virtual, actual in _thread_virtual_to_actual_mappings(thread_data).items()}
def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None) -> str:
def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
"""Mask host absolute paths from local sandbox output using virtual paths.
Handles user-data paths (per-thread), skills paths, and ACP workspace paths (global).
@@ -467,8 +437,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
result = output
# Mask skills host paths
skills_host = _get_skills_host_path()
skills_container = _get_skills_container_path()
skills_host = _get_skills_host_path(app_config)
skills_container = _get_skills_container_path(app_config)
if skills_host:
raw_base = str(Path(skills_host))
resolved_base = str(Path(skills_host).resolve())
@@ -542,7 +512,13 @@ def _reject_path_traversal(path: str) -> None:
raise PermissionError("Access denied: path traversal detected")
def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *, read_only: bool = False) -> None:
def validate_local_tool_path(
path: str,
thread_data: ThreadDataState | None,
app_config: AppConfig,
*,
read_only: bool = False,
) -> None:
"""Validate that a virtual path is allowed for local-sandbox access.
This function is a security gate — it checks whether *path* may be
@@ -571,7 +547,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
_reject_path_traversal(path)
# Skills paths — read-only access only
if _is_skills_path(path):
if _is_skills_path(path, app_config):
if not read_only:
raise PermissionError(f"Write access to skills path is not allowed: {path}")
return
@@ -587,13 +563,13 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
return
# Custom mount paths — respect read_only config
if _is_custom_mount_path(path):
mount = _get_custom_mount_for_path(path)
if _is_custom_mount_path(path, app_config):
mount = _get_custom_mount_for_path(path, app_config)
if mount and mount.read_only and not read_only:
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path(app_config)}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
@@ -624,18 +600,23 @@ def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataSta
raise PermissionError("Access denied: path traversal detected")
def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState) -> str:
def _resolve_and_validate_user_data_path(path: str, thread_data: ThreadDataState, app_config: AppConfig) -> str:
"""Resolve a /mnt/user-data virtual path and validate it stays in bounds.
Returns the resolved host path string.
``app_config`` is accepted for signature symmetry with the other resolver
helpers; the user-data resolution path itself is fully derivable from
``thread_data``.
"""
_ = app_config # noqa: F841 — kept for interface symmetry with sibling resolvers
resolved_str = replace_virtual_path(path, thread_data)
resolved = Path(resolved_str).resolve()
_validate_resolved_user_data_path(resolved, thread_data)
return str(resolved)
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None) -> None:
def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> None:
"""Validate absolute paths in local-sandbox bash commands.
This validation is only a best-effort guard for the explicit
@@ -659,7 +640,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
raise PermissionError(f"Unsafe file:// URL in command: {file_url_match.group()}. Use paths under {VIRTUAL_PATH_PREFIX}")
unsafe_paths: list[str] = []
allowed_paths = _get_mcp_allowed_paths()
allowed_paths = _get_mcp_allowed_paths(app_config)
for absolute_path in _ABSOLUTE_PATH_PATTERN.findall(command):
# Check for MCP filesystem server allowed paths
@@ -672,7 +653,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
continue
# Allow skills container path (resolved by tools.py before passing to sandbox)
if _is_skills_path(absolute_path):
if _is_skills_path(absolute_path, app_config):
_reject_path_traversal(absolute_path)
continue
@@ -682,7 +663,7 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
continue
# Allow custom mount container paths
if _is_custom_mount_path(absolute_path):
if _is_custom_mount_path(absolute_path, app_config):
_reject_path_traversal(absolute_path)
continue
@@ -696,12 +677,13 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
raise PermissionError(f"Unsafe absolute paths in command: {unsafe}. Use paths under {VIRTUAL_PATH_PREFIX}")
def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None) -> str:
def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState | None, app_config: AppConfig) -> str:
"""Replace all virtual paths (/mnt/user-data, /mnt/skills, /mnt/acp-workspace) in a command string.
Args:
command: The command string that may contain virtual paths.
thread_data: The thread data containing actual paths.
app_config: Resolved application config.
Returns:
The command with all virtual paths replaced.
@@ -709,13 +691,13 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
result = command
# Replace skills paths
skills_container = _get_skills_container_path()
skills_host = _get_skills_host_path()
skills_container = _get_skills_container_path(app_config)
skills_host = _get_skills_host_path(app_config)
if skills_host and skills_container in result:
skills_pattern = re.compile(rf"{re.escape(skills_container)}(/[^\s\"';&|<>()]*)?")
def replace_skills_match(match: re.Match) -> str:
return _resolve_skills_path(match.group(0))
return _resolve_skills_path(match.group(0), app_config)
result = skills_pattern.sub(replace_skills_match, result)
@@ -805,12 +787,10 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is None:
raise SandboxRuntimeError("Sandbox ID not found in state")
sandbox = get_sandbox_provider().get(sandbox_id)
sandbox = get_sandbox_provider(resolve_context(runtime).app_config).get(sandbox_id)
if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox
@@ -838,26 +818,24 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if runtime.state is None:
raise SandboxRuntimeError("Tool runtime state not available")
app_config = runtime.context.app_config
# Check if sandbox already exists in state
sandbox_state = runtime.state.get("sandbox")
if sandbox_state is not None:
sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
sandbox = get_sandbox_provider(app_config).get(sandbox_id)
if sandbox is not None:
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
# Sandbox was released, fall through to acquire new one
# Lazy acquisition: get thread_id and acquire sandbox
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
if thread_id is None:
thread_id = runtime.context.thread_id
if not thread_id:
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider()
provider = get_sandbox_provider(app_config)
sandbox_id = provider.acquire(thread_id)
# Update runtime state - this persists across tool calls
@@ -868,8 +846,6 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
@@ -999,40 +975,29 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com
description: Explain why you are running this command in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
command: The bash command to execute. Always use absolute paths for files and directories.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
if is_local_sandbox(runtime):
if not is_host_bash_allowed():
if not is_host_bash_allowed(app_config):
return f"Error: {LOCAL_HOST_BASH_DISABLED_MESSAGE}"
ensure_thread_directories_exist(runtime)
thread_data = get_thread_data(runtime)
validate_local_bash_command_paths(command, thread_data)
command = replace_virtual_paths_in_command(command, thread_data)
validate_local_bash_command_paths(command, thread_data, app_config)
command = replace_virtual_paths_in_command(command, thread_data, app_config)
command = _apply_cwd_prefix(command, thread_data)
output = sandbox.execute_command(command)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data), max_chars)
return _truncate_bash_output(mask_local_paths_in_output(output, thread_data, app_config), max_chars)
ensure_thread_directories_exist(runtime)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.bash_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
return _truncate_bash_output(sandbox.execute_command(command), max_chars)
except SandboxError as e:
return f"Error: {e}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime, app_config)}"
@tool("ls", parse_docstring=True)
@@ -1043,6 +1008,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
description: Explain why you are listing this directory in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
path: The **absolute** path to the directory to list.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
@@ -1050,13 +1016,13 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
thread_data = None
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
path = _resolve_skills_path(path)
validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path, app_config):
path = _resolve_skills_path(path, app_config)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
elif not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
children = sandbox.list_dir(path)
if not children:
@@ -1064,13 +1030,8 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
output = "\n".join(children)
if thread_data is not None:
output = mask_local_paths_in_output(output, thread_data)
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
except Exception:
max_chars = 20000
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.ls_output_max_chars if sandbox_cfg else 20000
return _truncate_ls_output(output, max_chars)
except SandboxError as e:
return f"Error: {e}"
@@ -1079,7 +1040,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime, app_config)}"
@tool("glob", parse_docstring=True)
@@ -1100,11 +1061,13 @@ def glob_tool(
include_dirs: Whether matching directories should also be returned. Default is False.
max_results: Maximum number of paths to return. Default is 200.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
app_config,
"glob",
max_results,
default=_DEFAULT_GLOB_MAX_RESULTS,
@@ -1115,10 +1078,10 @@ def glob_tool(
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
path = _resolve_local_read_path(path, thread_data, app_config)
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
if thread_data is not None:
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
matches = [mask_local_paths_in_output(match, thread_data, app_config) for match in matches]
return _format_glob_results(requested_path, matches, truncated)
except SandboxError as e:
return f"Error: {e}"
@@ -1129,7 +1092,7 @@ def glob_tool(
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime, app_config)}"
@tool("grep", parse_docstring=True)
@@ -1154,11 +1117,13 @@ def grep_tool(
case_sensitive: Whether matching is case-sensitive. Default is False.
max_results: Maximum number of matching lines to return. Default is 100.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
app_config,
"grep",
max_results,
default=_DEFAULT_GREP_MAX_RESULTS,
@@ -1169,7 +1134,7 @@ def grep_tool(
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
path = _resolve_local_read_path(path, thread_data, app_config)
matches, truncated = sandbox.grep(
path,
pattern,
@@ -1181,7 +1146,7 @@ def grep_tool(
if thread_data is not None:
matches = [
GrepMatch(
path=mask_local_paths_in_output(match.path, thread_data),
path=mask_local_paths_in_output(match.path, thread_data, app_config),
line_number=match.line_number,
line=match.line,
)
@@ -1199,7 +1164,7 @@ def grep_tool(
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime, app_config)}"
@tool("read_file", parse_docstring=True)
@@ -1218,32 +1183,28 @@ def read_file_tool(
start_line: Optional starting line number (1-indexed, inclusive). Use with end_line to read a specific range.
end_line: Optional ending line number (1-indexed, inclusive). Use with start_line to read a specific range.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
path = _resolve_skills_path(path)
validate_local_tool_path(path, thread_data, app_config, read_only=True)
if _is_skills_path(path, app_config):
path = _resolve_skills_path(path, app_config)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
elif not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
content = sandbox.read_file(path)
if not content:
return "(empty)"
if start_line is not None and end_line is not None:
content = "\n".join(content.splitlines()[start_line - 1 : end_line])
try:
from deerflow.config.app_config import get_app_config
sandbox_cfg = get_app_config().sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
except Exception:
max_chars = 50000
sandbox_cfg = app_config.sandbox
max_chars = sandbox_cfg.read_file_output_max_chars if sandbox_cfg else 50000
return _truncate_read_file_output(content, max_chars)
except SandboxError as e:
return f"Error: {e}"
@@ -1254,7 +1215,7 @@ def read_file_tool(
except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}"
except Exception as e:
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime, app_config)}"
@tool("write_file", parse_docstring=True)
@@ -1272,15 +1233,16 @@ def write_file_tool(
path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND.
content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
validate_local_tool_path(path, thread_data, app_config)
if not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
sandbox.write_file(path, content, append)
@@ -1292,9 +1254,9 @@ def write_file_tool(
except IsADirectoryError:
return f"Error: Path is a directory, not a file: {requested_path}"
except OSError as e:
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}"
return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime, app_config)}"
except Exception as e:
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime, app_config)}"
@tool("str_replace", parse_docstring=True)
@@ -1316,15 +1278,16 @@ def str_replace_tool(
new_str: The new substring. ALWAYS PROVIDE THIS PARAMETER FOURTH.
replace_all: Whether to replace all occurrences of the substring. If False, only the first occurrence will be replaced. Default is False.
"""
app_config = resolve_context(runtime).app_config
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
validate_local_tool_path(path, thread_data, app_config)
if not _is_custom_mount_path(path, app_config):
path = _resolve_and_validate_user_data_path(path, thread_data, app_config)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
content = sandbox.read_file(path)
@@ -1345,4 +1308,4 @@ def str_replace_tool(
except PermissionError:
return f"Error: Permission denied accessing file: {requested_path}"
except Exception as e:
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}"
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime, app_config)}"
@@ -1,10 +1,14 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING
from .parser import parse_skill_file
from .types import Skill
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@@ -22,7 +26,12 @@ def get_skills_root_path() -> Path:
return skills_dir
def load_skills(skills_path: Path | None = None, use_config: bool = True, enabled_only: bool = False) -> list[Skill]:
def load_skills(
app_config: "AppConfig | None" = None,
*,
skills_path: Path | None = None,
enabled_only: bool = False,
) -> list[Skill]:
"""
Load all skills from the skills directory.
@@ -30,25 +39,19 @@ def load_skills(skills_path: Path | None = None, use_config: bool = True, enable
to extract metadata. The enabled state is determined by the skills_state_config.json file.
Args:
skills_path: Optional custom path to skills directory.
If not provided and use_config is True, uses path from config.
Otherwise defaults to deer-flow/skills
use_config: Whether to load skills path from config (default: True)
app_config: Application config used to resolve the configured skills
directory. Ignored when ``skills_path`` is supplied.
skills_path: Explicit override for the skills directory. When both
``skills_path`` and ``app_config`` are omitted the
default repository layout is used (``deer-flow/skills``).
enabled_only: If True, only return enabled skills (default: False)
Returns:
List of Skill objects, sorted by name
"""
if skills_path is None:
if use_config:
try:
from deerflow.config import get_app_config
config = get_app_config()
skills_path = config.skills.get_skills_path()
except Exception:
# Fallback to default if config fails
skills_path = get_skills_root_path()
if app_config is not None:
skills_path = app_config.skills.get_skills_path()
else:
skills_path = get_skills_root_path()
@@ -9,7 +9,7 @@ from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.skills.loader import load_skills
from deerflow.skills.validation import _validate_skill_frontmatter
@@ -20,16 +20,17 @@ ALLOWED_SUPPORT_SUBDIRS = {"references", "templates", "scripts", "assets"}
_SKILL_NAME_PATTERN = re.compile(r"^[a-z0-9]+(?:-[a-z0-9]+)*$")
def get_skills_root_dir() -> Path:
return get_app_config().skills.get_skills_path()
def get_skills_root_dir(app_config: AppConfig) -> Path:
"""Return the configured skills root."""
return app_config.skills.get_skills_path()
def get_public_skills_dir() -> Path:
return get_skills_root_dir() / "public"
def get_public_skills_dir(app_config: AppConfig) -> Path:
return get_skills_root_dir(app_config) / "public"
def get_custom_skills_dir() -> Path:
path = get_skills_root_dir() / "custom"
def get_custom_skills_dir(app_config: AppConfig) -> Path:
path = get_skills_root_dir(app_config) / "custom"
path.mkdir(parents=True, exist_ok=True)
return path
@@ -43,46 +44,46 @@ def validate_skill_name(name: str) -> str:
return normalized
def get_custom_skill_dir(name: str) -> Path:
return get_custom_skills_dir() / validate_skill_name(name)
def get_custom_skill_dir(name: str, app_config: AppConfig) -> Path:
return get_custom_skills_dir(app_config) / validate_skill_name(name)
def get_custom_skill_file(name: str) -> Path:
return get_custom_skill_dir(name) / SKILL_FILE_NAME
def get_custom_skill_file(name: str, app_config: AppConfig) -> Path:
return get_custom_skill_dir(name, app_config) / SKILL_FILE_NAME
def get_custom_skill_history_dir() -> Path:
path = get_custom_skills_dir() / HISTORY_DIR_NAME
def get_custom_skill_history_dir(app_config: AppConfig) -> Path:
path = get_custom_skills_dir(app_config) / HISTORY_DIR_NAME
path.mkdir(parents=True, exist_ok=True)
return path
def get_skill_history_file(name: str) -> Path:
return get_custom_skill_history_dir() / f"{validate_skill_name(name)}.jsonl"
def get_skill_history_file(name: str, app_config: AppConfig) -> Path:
return get_custom_skill_history_dir(app_config) / f"{validate_skill_name(name)}.jsonl"
def get_public_skill_dir(name: str) -> Path:
return get_public_skills_dir() / validate_skill_name(name)
def get_public_skill_dir(name: str, app_config: AppConfig) -> Path:
return get_public_skills_dir(app_config) / validate_skill_name(name)
def custom_skill_exists(name: str) -> bool:
return get_custom_skill_file(name).exists()
def custom_skill_exists(name: str, app_config: AppConfig) -> bool:
return get_custom_skill_file(name, app_config).exists()
def public_skill_exists(name: str) -> bool:
return (get_public_skill_dir(name) / SKILL_FILE_NAME).exists()
def public_skill_exists(name: str, app_config: AppConfig) -> bool:
return (get_public_skill_dir(name, app_config) / SKILL_FILE_NAME).exists()
def ensure_custom_skill_is_editable(name: str) -> None:
if custom_skill_exists(name):
def ensure_custom_skill_is_editable(name: str, app_config: AppConfig) -> None:
if custom_skill_exists(name, app_config):
return
if public_skill_exists(name):
if public_skill_exists(name, app_config):
raise ValueError(f"'{name}' is a built-in skill. To customise it, create a new skill with the same name under skills/custom/.")
raise FileNotFoundError(f"Custom skill '{name}' not found.")
def ensure_safe_support_path(name: str, relative_path: str) -> Path:
skill_dir = get_custom_skill_dir(name).resolve()
def ensure_safe_support_path(name: str, relative_path: str, app_config: AppConfig) -> Path:
skill_dir = get_custom_skill_dir(name, app_config).resolve()
if not relative_path or relative_path.endswith("/"):
raise ValueError("Supporting file path must include a filename.")
relative = Path(relative_path)
@@ -124,8 +125,8 @@ def atomic_write(path: Path, content: str) -> None:
tmp_path.replace(path)
def append_history(name: str, record: dict[str, Any]) -> None:
history_path = get_skill_history_file(name)
def append_history(name: str, record: dict[str, Any], app_config: AppConfig) -> None:
history_path = get_skill_history_file(name, app_config)
history_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"ts": datetime.now(UTC).isoformat(),
@@ -136,8 +137,8 @@ def append_history(name: str, record: dict[str, Any]) -> None:
f.write("\n")
def read_history(name: str) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name)
def read_history(name: str, app_config: AppConfig) -> list[dict[str, Any]]:
history_path = get_skill_history_file(name, app_config)
if not history_path.exists():
return []
records: list[dict[str, Any]] = []
@@ -148,12 +149,12 @@ def read_history(name: str) -> list[dict[str, Any]]:
return records
def list_custom_skills() -> list:
return [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
def list_custom_skills(app_config: AppConfig) -> list:
return [skill for skill in load_skills(app_config, enabled_only=False) if skill.category == "custom"]
def read_custom_skill_content(name: str) -> str:
skill_file = get_custom_skill_file(name)
def read_custom_skill_content(name: str, app_config: AppConfig) -> str:
skill_file = get_custom_skill_file(name, app_config)
if not skill_file.exists():
raise FileNotFoundError(f"Custom skill '{name}' not found.")
return skill_file.read_text(encoding="utf-8")
@@ -7,7 +7,7 @@ import logging
import re
from dataclasses import dataclass
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@@ -35,7 +35,7 @@ def _extract_json_object(raw: str) -> dict | None:
return None
async def scan_skill_content(content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
async def scan_skill_content(app_config: AppConfig, content: str, *, executable: bool = False, location: str = "SKILL.md") -> ScanResult:
"""Screen skill content before it is written to disk."""
rubric = (
"You are a security reviewer for AI agent skills. "
@@ -47,9 +47,12 @@ async def scan_skill_content(content: str, *, executable: bool = False, location
prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----"
try:
config = get_app_config()
model_name = config.skill_evolution.moderation_model_name
model = create_chat_model(name=model_name, thinking_enabled=False) if model_name else create_chat_model(thinking_enabled=False)
model_name = app_config.skill_evolution.moderation_model_name
model = (
create_chat_model(name=model_name, thinking_enabled=False, app_config=app_config)
if model_name
else create_chat_model(thinking_enabled=False, app_config=app_config)
)
response = await model.ainvoke(
[
{"role": "system", "content": rubric},

Some files were not shown because too many files have changed in this diff Show More