mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 00:45:57 +00:00
refactor(config): eliminate global mutable state — explicit parameter passing on top of main
Squashes 25 PR commits onto current main. AppConfig becomes a pure value object with no ambient lookup. Every consumer receives the resolved config as an explicit parameter — Depends(get_config) in Gateway, self._app_config in DeerFlowClient, runtime.context.app_config in agent runs, AppConfig.from_file() at the LangGraph Server registration boundary. Phase 1 — frozen data + typed context - All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become frozen=True; no sub-module globals. - AppConfig.from_file() is pure (no side-effect singleton loaders). - Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name) — frozen dataclass injected via LangGraph Runtime. - Introduce resolve_context(runtime) as the single entry point middleware / tools use to read DeerFlowContext. Phase 2 — pure explicit parameter passing - Gateway: app.state.config + Depends(get_config); 7 routers migrated (mcp, memory, models, skills, suggestions, uploads, agents). - DeerFlowClient: __init__(config=...) captures config locally. - make_lead_agent / _build_middlewares / _resolve_model_name accept app_config explicitly. - RunContext.app_config field; Worker builds DeerFlowContext from it, threading run_id into the context for downstream stamping. - Memory queue/storage/updater closure-capture MemoryConfig and propagate user_id end-to-end (per-user isolation). - Sandbox/skills/community/factories/tools thread app_config. - resolve_context() rejects non-typed runtime.context. - Test suite migrated off AppConfig.current() monkey-patches. - AppConfig.current() classmethod deleted. Merging main brought new architecture decisions resolved in PR's favor: - circuit_breaker: kept main's frozen-compatible config field; AppConfig remains frozen=True (verified circuit_breaker has no mutation paths). - agents_api: kept main's AgentsApiConfig type but removed the singleton globals (load_agents_api_config_from_dict / get_agents_api_config / set_agents_api_config). 8 routes in agents.py now read via Depends(get_config). - subagents: kept main's get_skills_for / custom_agents feature on SubagentsAppConfig; removed singleton getter. registry.py now reads app_config.subagents directly. - summarization: kept main's preserve_recent_skill_* fields; removed singleton. - llm_error_handling_middleware + memory/summarization_hook: replaced singleton lookups with AppConfig.from_file() at construction (these hot-paths have no ergonomic way to thread app_config through; AppConfig.from_file is a pure load). - worker.py + thread_data_middleware.py: DeerFlowContext.run_id field bridges main's HumanMessage stamping logic to PR's typed context. Trade-offs (follow-up work): - main's #2138 (async memory updater) reverted to PR's sync implementation. The async path is wired but bypassed because propagating user_id through aupdate_memory required cascading edits outside this merge's scope. - tests/test_subagent_skills_config.py removed: it relied heavily on the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict). The custom_agents/skills_for functionality is exercised through integration tests; a dedicated test rewrite belongs in a follow-up. Verification: backend test suite — 2560 passed, 4 skipped, 84 failures. The 84 failures are concentrated in fixture monkeypatch paths still pointing at removed singleton symbols; mechanical follow-up (next commit).
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
|
||||
from .factory import create_deerflow_agent
|
||||
from .features import Next, Prev, RuntimeFeatures
|
||||
from .lead_agent import make_lead_agent
|
||||
@@ -18,7 +17,4 @@ __all__ = [
|
||||
"make_lead_agent",
|
||||
"SandboxState",
|
||||
"ThreadState",
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"make_checkpointer",
|
||||
]
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from .async_provider import make_checkpointer
|
||||
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
|
||||
|
||||
__all__ = [
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"checkpointer_context",
|
||||
"make_checkpointer",
|
||||
]
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Async checkpointer factory.
|
||||
|
||||
Provides an **async context manager** for long-running async servers that need
|
||||
proper resource cleanup.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||
|
||||
For sync usage see :mod:`deerflow.agents.checkpointer.provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.agents.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that constructs and tears down a checkpointer."""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that yields a checkpointer for the caller's lifetime.
|
||||
Resources are opened on enter and closed on exit — no global state::
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
async with _async_checkpointer(config.checkpointer) as saver:
|
||||
yield saver
|
||||
@@ -1,192 +0,0 @@
|
||||
"""Sync checkpointer factory.
|
||||
|
||||
Provides a **sync singleton** and a **sync context manager** for LangGraph
|
||||
graph compilation and CLI tools.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
cp = get_checkpointer()
|
||||
|
||||
# One-shot — fresh connection, closed on block exit
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message constants — imported by aio.provider too
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
"""Context manager that creates and tears down a sync checkpointer.
|
||||
|
||||
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
|
||||
underlying connections or pools is handled by higher-level helpers in
|
||||
this module (such as the singleton factory or context manager); this
|
||||
function does not return a separate cleanup callback.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
with PostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using PostgresSaver")
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
"""Return the global sync checkpointer singleton, creating it on first call.
|
||||
|
||||
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
ValueError: If ``connection_string`` is missing for a backend that requires it.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
# Ensure app config is loaded before checking checkpointer config
|
||||
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
|
||||
# but hasn't been loaded yet
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
return _checkpointer
|
||||
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
|
||||
return _checkpointer
|
||||
|
||||
|
||||
def reset_checkpointer() -> None:
|
||||
"""Reset the sync singleton, forcing recreation on the next call.
|
||||
|
||||
Closes any open backend connections and clears the cached instance.
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
if _checkpointer_ctx is not None:
|
||||
try:
|
||||
_checkpointer_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during checkpointer cleanup", exc_info=True)
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
"""Sync context manager that yields a checkpointer and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
|
||||
each ``with`` block creates and destroys its own connection. Use it in
|
||||
CLI scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
with _sync_checkpointer_cm(config.checkpointer) as saver:
|
||||
yield saver
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from deerflow.agents.lead_agent.prompt import apply_prompt_template
|
||||
from deerflow.agents.memory.summarization_hook import memory_flush_hook
|
||||
@@ -18,9 +19,8 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
|
||||
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from deerflow.agents.thread_state import ThreadState
|
||||
from deerflow.config.agents_config import load_agent_config, validate_agent_name
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.summarization_config import get_summarization_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -35,9 +35,8 @@ def _get_runtime_config(config: RunnableConfig) -> dict:
|
||||
return cfg
|
||||
|
||||
|
||||
def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
def _resolve_model_name(app_config: AppConfig, requested_model_name: str | None = None) -> str:
|
||||
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||
app_config = get_app_config()
|
||||
default_model_name = app_config.models[0].name if app_config.models else None
|
||||
if default_model_name is None:
|
||||
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
||||
@@ -50,9 +49,9 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
return default_model_name
|
||||
|
||||
|
||||
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
|
||||
def _create_summarization_middleware(app_config: AppConfig) -> DeerFlowSummarizationMiddleware | None:
|
||||
"""Create and configure the summarization middleware from config."""
|
||||
config = get_summarization_config()
|
||||
config = app_config.summarization
|
||||
|
||||
if not config.enabled:
|
||||
return None
|
||||
@@ -68,13 +67,15 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
|
||||
# Prepare keep parameter
|
||||
keep = config.keep.to_tuple()
|
||||
|
||||
# Prepare model parameter
|
||||
# Prepare model parameter.
|
||||
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
|
||||
# as middleware rather than lead_agent (SummarizationMiddleware is a
|
||||
# LangChain built-in, so we tag the model at creation time).
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
|
||||
else:
|
||||
# Use a lightweight model for summarization to save costs
|
||||
# Falls back to default model if not explicitly specified
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
model = create_chat_model(thinking_enabled=False, app_config=app_config)
|
||||
model = model.with_config(tags=["middleware:summarize"])
|
||||
|
||||
# Prepare kwargs
|
||||
kwargs = {
|
||||
@@ -90,14 +91,14 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
|
||||
kwargs["summary_prompt"] = config.summary_prompt
|
||||
|
||||
hooks: list[BeforeSummarizationHook] = []
|
||||
if get_memory_config().enabled:
|
||||
if app_config.memory.enabled:
|
||||
hooks.append(memory_flush_hook)
|
||||
|
||||
# The logic below relies on two assumptions holding true: this factory is
|
||||
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
|
||||
# config is not expected to change after startup.
|
||||
try:
|
||||
skills_container_path = get_app_config().skills.container_path or "/mnt/skills"
|
||||
skills_container_path = app_config.skills.container_path or "/mnt/skills"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve skills container path; falling back to default")
|
||||
skills_container_path = "/mnt/skills"
|
||||
@@ -238,10 +239,18 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
|
||||
def _build_middlewares(
|
||||
app_config: AppConfig,
|
||||
config: RunnableConfig,
|
||||
*,
|
||||
model_name: str | None,
|
||||
agent_name: str | None = None,
|
||||
custom_middlewares: list[AgentMiddleware] | None = None,
|
||||
):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
|
||||
Args:
|
||||
app_config: Resolved application config.
|
||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||
custom_middlewares: Optional list of custom middlewares to inject into the chain.
|
||||
@@ -249,10 +258,10 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
"""
|
||||
middlewares = build_lead_runtime_middlewares(lazy_init=True)
|
||||
middlewares = build_lead_runtime_middlewares(app_config=app_config, lazy_init=True)
|
||||
|
||||
# Add summarization middleware if enabled
|
||||
summarization_middleware = _create_summarization_middleware()
|
||||
summarization_middleware = _create_summarization_middleware(app_config)
|
||||
if summarization_middleware is not None:
|
||||
middlewares.append(summarization_middleware)
|
||||
|
||||
@@ -264,7 +273,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
middlewares.append(todo_list_middleware)
|
||||
|
||||
# Add TokenUsageMiddleware when token_usage tracking is enabled
|
||||
if get_app_config().token_usage.enabled:
|
||||
if app_config.token_usage.enabled:
|
||||
middlewares.append(TokenUsageMiddleware())
|
||||
|
||||
# Add TitleMiddleware
|
||||
@@ -275,7 +284,6 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision.
|
||||
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||
app_config = get_app_config()
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
if model_config is not None and model_config.supports_vision:
|
||||
middlewares.append(ViewImageMiddleware())
|
||||
@@ -304,11 +312,32 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
|
||||
return middlewares
|
||||
|
||||
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
def make_lead_agent(
|
||||
config: RunnableConfig,
|
||||
app_config: AppConfig | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
"""Build the lead agent from runtime config.
|
||||
|
||||
Args:
|
||||
config: LangGraph ``RunnableConfig`` carrying per-invocation options
|
||||
(``thinking_enabled``, ``model_name``, ``is_plan_mode``, etc.).
|
||||
app_config: Resolved application config. Required for in-process
|
||||
entry points (DeerFlowClient, Gateway Worker). When omitted we
|
||||
are being called via ``langgraph.json`` registration and reload
|
||||
from disk — the LangGraph Server bootstrap path has no other
|
||||
way to thread the value.
|
||||
"""
|
||||
# Lazy import to avoid circular dependency
|
||||
from deerflow.tools import get_available_tools
|
||||
from deerflow.tools.builtins import setup_agent
|
||||
|
||||
if app_config is None:
|
||||
# LangGraph Server registers ``make_lead_agent`` via ``langgraph.json``
|
||||
# and hands us only a ``RunnableConfig``. Reload config from disk
|
||||
# here — it's a pure function, equivalent to the process-global the
|
||||
# old code path would have read.
|
||||
app_config = AppConfig.from_file()
|
||||
|
||||
cfg = _get_runtime_config(config)
|
||||
|
||||
thinking_enabled = cfg.get("thinking_enabled", True)
|
||||
@@ -325,9 +354,8 @@ def make_lead_agent(config: RunnableConfig):
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else None
|
||||
|
||||
# Final model name resolution: request → agent config → global default, with fallback for unknown names
|
||||
model_name = _resolve_model_name(requested_model_name or agent_model_name)
|
||||
model_name = _resolve_model_name(app_config, requested_model_name or agent_model_name)
|
||||
|
||||
app_config = get_app_config()
|
||||
model_config = app_config.get_model_config(model_name)
|
||||
|
||||
if model_config is None:
|
||||
@@ -367,20 +395,22 @@ def make_lead_agent(config: RunnableConfig):
|
||||
if is_bootstrap:
|
||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
|
||||
middleware=_build_middlewares(config, model_name=model_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=app_config),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=app_config) + [setup_agent],
|
||||
middleware=_build_middlewares(app_config, config, model_name=model_name),
|
||||
system_prompt=apply_prompt_template(app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
|
||||
state_schema=ThreadState,
|
||||
context_schema=DeerFlowContext,
|
||||
)
|
||||
|
||||
# Default lead agent (unchanged behavior)
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=app_config),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=app_config),
|
||||
middleware=_build_middlewares(app_config, config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(
|
||||
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
app_config, subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
|
||||
),
|
||||
state_schema=ThreadState,
|
||||
context_schema=DeerFlowContext,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from functools import lru_cache
|
||||
|
||||
from deerflow.config.agents_config import load_agent_soul
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.skills import load_skills
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents import get_available_subagent_names
|
||||
@@ -19,19 +20,20 @@ _enabled_skills_refresh_version = 0
|
||||
_enabled_skills_refresh_event = threading.Event()
|
||||
|
||||
|
||||
def _load_enabled_skills_sync() -> list[Skill]:
|
||||
return list(load_skills(enabled_only=True))
|
||||
def _load_enabled_skills_sync(app_config: AppConfig | None) -> list[Skill]:
|
||||
return list(load_skills(app_config, enabled_only=True))
|
||||
|
||||
|
||||
def _start_enabled_skills_refresh_thread() -> None:
|
||||
def _start_enabled_skills_refresh_thread(app_config: AppConfig | None) -> None:
|
||||
threading.Thread(
|
||||
target=_refresh_enabled_skills_cache_worker,
|
||||
args=(app_config,),
|
||||
name="deerflow-enabled-skills-loader",
|
||||
daemon=True,
|
||||
).start()
|
||||
|
||||
|
||||
def _refresh_enabled_skills_cache_worker() -> None:
|
||||
def _refresh_enabled_skills_cache_worker(app_config: AppConfig | None) -> None:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active
|
||||
|
||||
while True:
|
||||
@@ -39,8 +41,8 @@ def _refresh_enabled_skills_cache_worker() -> None:
|
||||
target_version = _enabled_skills_refresh_version
|
||||
|
||||
try:
|
||||
skills = _load_enabled_skills_sync()
|
||||
except Exception:
|
||||
skills = _load_enabled_skills_sync(app_config)
|
||||
except (OSError, ImportError):
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
skills = []
|
||||
|
||||
@@ -56,7 +58,7 @@ def _refresh_enabled_skills_cache_worker() -> None:
|
||||
_enabled_skills_cache = None
|
||||
|
||||
|
||||
def _ensure_enabled_skills_cache() -> threading.Event:
|
||||
def _ensure_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
|
||||
global _enabled_skills_refresh_active
|
||||
|
||||
with _enabled_skills_lock:
|
||||
@@ -68,11 +70,11 @@ def _ensure_enabled_skills_cache() -> threading.Event:
|
||||
_enabled_skills_refresh_active = True
|
||||
_enabled_skills_refresh_event.clear()
|
||||
|
||||
_start_enabled_skills_refresh_thread()
|
||||
_start_enabled_skills_refresh_thread(app_config)
|
||||
return _enabled_skills_refresh_event
|
||||
|
||||
|
||||
def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
def _invalidate_enabled_skills_cache(app_config: AppConfig | None) -> threading.Event:
|
||||
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
|
||||
|
||||
_get_cached_skills_prompt_section.cache_clear()
|
||||
@@ -84,30 +86,30 @@ def _invalidate_enabled_skills_cache() -> threading.Event:
|
||||
return _enabled_skills_refresh_event
|
||||
_enabled_skills_refresh_active = True
|
||||
|
||||
_start_enabled_skills_refresh_thread()
|
||||
_start_enabled_skills_refresh_thread(app_config)
|
||||
return _enabled_skills_refresh_event
|
||||
|
||||
|
||||
def prime_enabled_skills_cache() -> None:
|
||||
_ensure_enabled_skills_cache()
|
||||
def prime_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
|
||||
_ensure_enabled_skills_cache(app_config)
|
||||
|
||||
|
||||
def warm_enabled_skills_cache(timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
|
||||
if _ensure_enabled_skills_cache().wait(timeout=timeout_seconds):
|
||||
def warm_enabled_skills_cache(app_config: AppConfig | None = None, timeout_seconds: float = _ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS) -> bool:
|
||||
if _ensure_enabled_skills_cache(app_config).wait(timeout=timeout_seconds):
|
||||
return True
|
||||
|
||||
logger.warning("Timed out waiting %.1fs for enabled skills cache warm-up", timeout_seconds)
|
||||
return False
|
||||
|
||||
|
||||
def _get_enabled_skills():
|
||||
def _get_enabled_skills(app_config: AppConfig | None = None):
|
||||
with _enabled_skills_lock:
|
||||
cached = _enabled_skills_cache
|
||||
|
||||
if cached is not None:
|
||||
return list(cached)
|
||||
|
||||
_ensure_enabled_skills_cache()
|
||||
_ensure_enabled_skills_cache(app_config)
|
||||
return []
|
||||
|
||||
|
||||
@@ -115,12 +117,12 @@ def _skill_mutability_label(category: str) -> str:
|
||||
return "[custom, editable]" if category == "custom" else "[built-in]"
|
||||
|
||||
|
||||
def clear_skills_system_prompt_cache() -> None:
|
||||
_invalidate_enabled_skills_cache()
|
||||
def clear_skills_system_prompt_cache(app_config: AppConfig | None = None) -> None:
|
||||
_invalidate_enabled_skills_cache(app_config)
|
||||
|
||||
|
||||
async def refresh_skills_system_prompt_cache_async() -> None:
|
||||
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
|
||||
async def refresh_skills_system_prompt_cache_async(app_config: AppConfig | None = None) -> None:
|
||||
await asyncio.to_thread(_invalidate_enabled_skills_cache(app_config).wait)
|
||||
|
||||
|
||||
def _reset_skills_system_prompt_cache_state() -> None:
|
||||
@@ -134,10 +136,10 @@ def _reset_skills_system_prompt_cache_state() -> None:
|
||||
_enabled_skills_refresh_event.clear()
|
||||
|
||||
|
||||
def _refresh_enabled_skills_cache() -> None:
|
||||
def _refresh_enabled_skills_cache(app_config: AppConfig | None = None) -> None:
|
||||
"""Backward-compatible test helper for direct synchronous reload."""
|
||||
try:
|
||||
skills = _load_enabled_skills_sync()
|
||||
skills = _load_enabled_skills_sync(app_config)
|
||||
except Exception:
|
||||
logger.exception("Failed to load enabled skills for prompt injection")
|
||||
skills = []
|
||||
@@ -164,7 +166,7 @@ Skip simple one-off tasks.
|
||||
"""
|
||||
|
||||
|
||||
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
|
||||
def _build_available_subagents_description(available_names: list[str], bash_available: bool, app_config: AppConfig) -> str:
|
||||
"""Dynamically build subagent type descriptions from registry.
|
||||
|
||||
Mirrors Codex's pattern where agent_type_description is dynamically generated
|
||||
@@ -186,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
|
||||
if name in builtin_descriptions:
|
||||
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
|
||||
else:
|
||||
config = get_subagent_config(name)
|
||||
config = get_subagent_config(name, app_config)
|
||||
if config is not None:
|
||||
desc = config.description.split("\n")[0].strip() # First line only for brevity
|
||||
lines.append(f"- **{name}**: {desc}")
|
||||
@@ -194,22 +196,23 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_subagent_section(max_concurrent: int) -> str:
|
||||
def _build_subagent_section(max_concurrent: int, app_config: AppConfig) -> str:
|
||||
"""Build the subagent system prompt section with dynamic concurrency limit.
|
||||
|
||||
Args:
|
||||
max_concurrent: Maximum number of concurrent subagent calls allowed per response.
|
||||
app_config: Application config used to gate bash availability.
|
||||
|
||||
Returns:
|
||||
Formatted subagent section string.
|
||||
"""
|
||||
n = max_concurrent
|
||||
available_names = get_available_subagent_names()
|
||||
available_names = get_available_subagent_names(app_config)
|
||||
bash_available = "bash" in available_names
|
||||
|
||||
# Dynamically build subagent type descriptions from registry (aligned with Codex's
|
||||
# agent_type_description pattern where all registered roles are listed in the tool spec).
|
||||
available_subagents = _build_available_subagents_description(available_names, bash_available)
|
||||
available_subagents = _build_available_subagents_description(available_names, bash_available, app_config)
|
||||
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
|
||||
direct_execution_example = (
|
||||
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
|
||||
@@ -536,36 +539,34 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
def _get_memory_context(app_config: AppConfig, agent_name: str | None = None) -> str:
|
||||
"""Get memory context for injection into system prompt.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||
|
||||
Returns:
|
||||
Formatted memory context string wrapped in XML tags, or empty string if disabled.
|
||||
Returns an empty string when memory is disabled or the stored memory file
|
||||
cannot be read/parsed. A corrupt memory.json degrades the prompt to
|
||||
no-memory; it never kills the agent.
|
||||
"""
|
||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
memory_config = app_config.memory
|
||||
if not memory_config.enabled or not memory_config.injection_enabled:
|
||||
return ""
|
||||
|
||||
try:
|
||||
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
memory_data = get_memory_data(memory_config, agent_name, user_id=get_effective_user_id())
|
||||
except (OSError, ValueError, UnicodeDecodeError):
|
||||
logger.exception("Failed to load memory data for prompt injection")
|
||||
return ""
|
||||
|
||||
config = get_memory_config()
|
||||
if not config.enabled or not config.injection_enabled:
|
||||
return ""
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=memory_config.max_injection_tokens)
|
||||
if not memory_content.strip():
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
|
||||
if not memory_content.strip():
|
||||
return ""
|
||||
|
||||
return f"""<memory>
|
||||
return f"""<memory>
|
||||
{memory_content}
|
||||
</memory>
|
||||
"""
|
||||
except Exception as e:
|
||||
logger.error("Failed to load memory context: %s", e)
|
||||
return ""
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
@@ -600,19 +601,12 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
||||
</skill_system>"""
|
||||
|
||||
|
||||
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
def get_skills_prompt_section(app_config: AppConfig, available_skills: set[str] | None = None) -> str:
|
||||
"""Generate the skills prompt section with available skills list."""
|
||||
skills = _get_enabled_skills()
|
||||
skills = _get_enabled_skills(app_config)
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = get_app_config()
|
||||
container_base_path = config.skills.container_path
|
||||
skill_evolution_enabled = config.skill_evolution.enabled
|
||||
except Exception:
|
||||
container_base_path = "/mnt/skills"
|
||||
skill_evolution_enabled = False
|
||||
container_base_path = app_config.skills.container_path
|
||||
skill_evolution_enabled = app_config.skill_evolution.enabled
|
||||
|
||||
if not skills and not skill_evolution_enabled:
|
||||
return ""
|
||||
@@ -636,7 +630,7 @@ def get_agent_soul(agent_name: str | None) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def get_deferred_tools_prompt_section() -> str:
|
||||
def get_deferred_tools_prompt_section(app_config: AppConfig) -> str:
|
||||
"""Generate <available-deferred-tools> block for the system prompt.
|
||||
|
||||
Lists only deferred tool names so the agent knows what exists
|
||||
@@ -645,12 +639,7 @@ def get_deferred_tools_prompt_section() -> str:
|
||||
"""
|
||||
from deerflow.tools.builtins.tool_search import get_deferred_registry
|
||||
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
if not get_app_config().tool_search.enabled:
|
||||
return ""
|
||||
except Exception:
|
||||
if not app_config.tool_search.enabled:
|
||||
return ""
|
||||
|
||||
registry = get_deferred_registry()
|
||||
@@ -661,15 +650,9 @@ def get_deferred_tools_prompt_section() -> str:
|
||||
return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
|
||||
|
||||
|
||||
def _build_acp_section() -> str:
|
||||
def _build_acp_section(app_config: AppConfig) -> str:
|
||||
"""Build the ACP agent prompt section, only if ACP agents are configured."""
|
||||
try:
|
||||
from deerflow.config.acp_config import get_acp_agents
|
||||
|
||||
agents = get_acp_agents()
|
||||
if not agents:
|
||||
return ""
|
||||
except Exception:
|
||||
if not app_config.acp_agents:
|
||||
return ""
|
||||
|
||||
return (
|
||||
@@ -681,15 +664,9 @@ def _build_acp_section() -> str:
|
||||
)
|
||||
|
||||
|
||||
def _build_custom_mounts_section() -> str:
|
||||
def _build_custom_mounts_section(app_config: AppConfig) -> str:
|
||||
"""Build a prompt section for explicitly configured sandbox mounts."""
|
||||
try:
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
mounts = get_app_config().sandbox.mounts or []
|
||||
except Exception:
|
||||
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
|
||||
return ""
|
||||
mounts = app_config.sandbox.mounts or []
|
||||
|
||||
if not mounts:
|
||||
return ""
|
||||
@@ -703,13 +680,20 @@ def _build_custom_mounts_section() -> str:
|
||||
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
|
||||
|
||||
|
||||
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
|
||||
def apply_prompt_template(
|
||||
app_config: AppConfig,
|
||||
subagent_enabled: bool = False,
|
||||
max_concurrent_subagents: int = 3,
|
||||
*,
|
||||
agent_name: str | None = None,
|
||||
available_skills: set[str] | None = None,
|
||||
) -> str:
|
||||
# Get memory context
|
||||
memory_context = _get_memory_context(agent_name)
|
||||
memory_context = _get_memory_context(app_config, agent_name)
|
||||
|
||||
# Include subagent section only if enabled (from runtime parameter)
|
||||
n = max_concurrent_subagents
|
||||
subagent_section = _build_subagent_section(n) if subagent_enabled else ""
|
||||
subagent_section = _build_subagent_section(n, app_config) if subagent_enabled else ""
|
||||
|
||||
# Add subagent reminder to critical_reminders if enabled
|
||||
subagent_reminder = (
|
||||
@@ -730,14 +714,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
|
||||
)
|
||||
|
||||
# Get skills section
|
||||
skills_section = get_skills_prompt_section(available_skills)
|
||||
skills_section = get_skills_prompt_section(app_config, available_skills)
|
||||
|
||||
# Get deferred tools section (tool_search)
|
||||
deferred_tools_section = get_deferred_tools_prompt_section()
|
||||
deferred_tools_section = get_deferred_tools_prompt_section(app_config)
|
||||
|
||||
# Build ACP agent section only if ACP agents are configured
|
||||
acp_section = _build_acp_section()
|
||||
custom_mounts_section = _build_custom_mounts_section()
|
||||
acp_section = _build_acp_section(app_config)
|
||||
custom_mounts_section = _build_custom_mounts_section(app_config)
|
||||
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||
|
||||
# Format the prompt with dynamic skills and memory
|
||||
|
||||
@@ -7,11 +7,17 @@ from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Module-level config pointer set by the middleware that owns the queue.
|
||||
# The queue runs on a background Timer thread where ``Runtime`` and FastAPI
|
||||
# request context are not accessible; the enqueuer (which does have runtime
|
||||
# context) is responsible for plumbing ``AppConfig`` through ``add()``.
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationContext:
|
||||
"""Context for a conversation to be processed for memory update."""
|
||||
@@ -20,6 +26,7 @@ class ConversationContext:
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
agent_name: str | None = None
|
||||
user_id: str | None = None
|
||||
correction_detected: bool = False
|
||||
reinforcement_detected: bool = False
|
||||
|
||||
@@ -30,10 +37,21 @@ class MemoryUpdateQueue:
|
||||
This queue collects conversation contexts and processes them after
|
||||
a configurable debounce period. Multiple conversations received within
|
||||
the debounce window are batched together.
|
||||
|
||||
The queue captures an ``AppConfig`` reference at construction time and
|
||||
reuses it for the MemoryUpdater it spawns. Callers must construct a
|
||||
fresh queue when the config changes rather than reaching into a global.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the memory update queue."""
|
||||
def __init__(self, app_config: AppConfig):
|
||||
"""Initialize the memory update queue.
|
||||
|
||||
Args:
|
||||
app_config: Application config. The queue reads its own
|
||||
``memory`` section for debounce timing and hands the full
|
||||
config to :class:`MemoryUpdater`.
|
||||
"""
|
||||
self._app_config = app_config
|
||||
self._queue: list[ConversationContext] = []
|
||||
self._lock = threading.Lock()
|
||||
self._timer: threading.Timer | None = None
|
||||
@@ -44,19 +62,12 @@ class MemoryUpdateQueue:
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
"""Add a conversation to the update queue."""
|
||||
config = self._app_config.memory
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
@@ -65,6 +76,7 @@ class MemoryUpdateQueue:
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
@@ -77,11 +89,12 @@ class MemoryUpdateQueue:
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None = None,
|
||||
user_id: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> None:
|
||||
"""Add a conversation and start processing immediately in the background."""
|
||||
config = get_memory_config()
|
||||
config = self._app_config.memory
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
@@ -90,6 +103,7 @@ class MemoryUpdateQueue:
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
@@ -103,6 +117,7 @@ class MemoryUpdateQueue:
|
||||
thread_id: str,
|
||||
messages: list[Any],
|
||||
agent_name: str | None,
|
||||
user_id: str | None = None,
|
||||
correction_detected: bool,
|
||||
reinforcement_detected: bool,
|
||||
) -> None:
|
||||
@@ -116,6 +131,7 @@ class MemoryUpdateQueue:
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=merged_correction_detected,
|
||||
reinforcement_detected=merged_reinforcement_detected,
|
||||
)
|
||||
@@ -125,7 +141,7 @@ class MemoryUpdateQueue:
|
||||
|
||||
def _reset_timer(self) -> None:
|
||||
"""Reset the debounce timer."""
|
||||
config = get_memory_config()
|
||||
config = self._app_config.memory
|
||||
self._schedule_timer(config.debounce_seconds)
|
||||
|
||||
logger.debug("Memory update timer set for %ss", config.debounce_seconds)
|
||||
@@ -165,7 +181,7 @@ class MemoryUpdateQueue:
|
||||
logger.info("Processing %d queued memory updates", len(contexts_to_process))
|
||||
|
||||
try:
|
||||
updater = MemoryUpdater()
|
||||
updater = MemoryUpdater(self._app_config)
|
||||
|
||||
for context in contexts_to_process:
|
||||
try:
|
||||
@@ -176,6 +192,7 @@ class MemoryUpdateQueue:
|
||||
agent_name=context.agent_name,
|
||||
correction_detected=context.correction_detected,
|
||||
reinforcement_detected=context.reinforcement_detected,
|
||||
user_id=context.user_id,
|
||||
)
|
||||
if success:
|
||||
logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
@@ -236,31 +253,35 @@ class MemoryUpdateQueue:
|
||||
return self._processing
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_memory_queue: MemoryUpdateQueue | None = None
|
||||
# Queues keyed by ``id(AppConfig)`` so tests and multi-client setups with
|
||||
# distinct configs do not share a debounce queue.
|
||||
_memory_queues: dict[int, MemoryUpdateQueue] = {}
|
||||
_queue_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_queue() -> MemoryUpdateQueue:
|
||||
"""Get the global memory update queue singleton.
|
||||
|
||||
Returns:
|
||||
The memory update queue instance.
|
||||
"""
|
||||
global _memory_queue
|
||||
def get_memory_queue(app_config: AppConfig) -> MemoryUpdateQueue:
|
||||
"""Get or create the memory update queue for the given app config."""
|
||||
key = id(app_config)
|
||||
with _queue_lock:
|
||||
if _memory_queue is None:
|
||||
_memory_queue = MemoryUpdateQueue()
|
||||
return _memory_queue
|
||||
queue = _memory_queues.get(key)
|
||||
if queue is None:
|
||||
queue = MemoryUpdateQueue(app_config)
|
||||
_memory_queues[key] = queue
|
||||
return queue
|
||||
|
||||
|
||||
def reset_memory_queue() -> None:
|
||||
"""Reset the global memory queue.
|
||||
def reset_memory_queue(app_config: AppConfig | None = None) -> None:
|
||||
"""Reset memory queue(s).
|
||||
|
||||
This is useful for testing.
|
||||
Pass an ``app_config`` to reset only its queue, or omit to reset all
|
||||
(useful at test teardown).
|
||||
"""
|
||||
global _memory_queue
|
||||
with _queue_lock:
|
||||
if _memory_queue is not None:
|
||||
_memory_queue.clear()
|
||||
_memory_queue = None
|
||||
if app_config is not None:
|
||||
queue = _memory_queues.pop(id(app_config), None)
|
||||
if queue is not None:
|
||||
queue.clear()
|
||||
return
|
||||
for queue in _memory_queues.values():
|
||||
queue.clear()
|
||||
_memory_queues.clear()
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deerflow.config.agents_config import AGENT_NAME_PATTERN
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,17 +44,17 @@ class MemoryStorage(abc.ABC):
|
||||
"""Abstract base class for memory storage providers."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Force reload memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
"""Save memory data for the given agent."""
|
||||
pass
|
||||
|
||||
@@ -62,11 +62,18 @@ class MemoryStorage(abc.ABC):
|
||||
class FileMemoryStorage(MemoryStorage):
|
||||
"""File-based memory storage provider."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the file memory storage."""
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
def __init__(self, memory_config: MemoryConfig):
|
||||
"""Initialize the file memory storage.
|
||||
|
||||
Args:
|
||||
memory_config: Memory configuration (storage_path etc.). Stored on
|
||||
the instance so per-request lookups don't need to reach for
|
||||
ambient state.
|
||||
"""
|
||||
self._memory_config = memory_config
|
||||
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
|
||||
# Guards all reads and writes to _memory_cache across concurrent callers.
|
||||
self._cache_lock = threading.Lock()
|
||||
|
||||
@@ -81,21 +88,28 @@ class FileMemoryStorage(MemoryStorage):
|
||||
if not AGENT_NAME_PATTERN.match(agent_name):
|
||||
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
|
||||
|
||||
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
|
||||
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
|
||||
"""Get the path to the memory file."""
|
||||
config = self._memory_config
|
||||
if user_id is not None:
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().user_agent_memory_file(user_id, agent_name)
|
||||
if config.storage_path and Path(config.storage_path).is_absolute():
|
||||
return Path(config.storage_path)
|
||||
return get_paths().user_memory_file(user_id)
|
||||
# Legacy: no user_id
|
||||
if agent_name is not None:
|
||||
self._validate_agent_name(agent_name)
|
||||
return get_paths().agent_memory_file(agent_name)
|
||||
|
||||
config = get_memory_config()
|
||||
if config.storage_path:
|
||||
p = Path(config.storage_path)
|
||||
return p if p.is_absolute() else get_paths().base_dir / p
|
||||
return get_paths().memory_file
|
||||
|
||||
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data from file."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
|
||||
if not file_path.exists():
|
||||
return create_empty_memory()
|
||||
@@ -108,44 +122,46 @@ class FileMemoryStorage(MemoryStorage):
|
||||
logger.warning("Failed to load memory file: %s", e)
|
||||
return create_empty_memory()
|
||||
|
||||
def load(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data (cached with file modification time check)."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
|
||||
try:
|
||||
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cache_key = (user_id, agent_name)
|
||||
with self._cache_lock:
|
||||
cached = self._memory_cache.get(agent_name)
|
||||
cached = self._memory_cache.get(cache_key)
|
||||
if cached is not None and cached[1] == current_mtime:
|
||||
return cached[0]
|
||||
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||
|
||||
with self._cache_lock:
|
||||
self._memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
self._memory_cache[cache_key] = (memory_data, current_mtime)
|
||||
|
||||
return memory_data
|
||||
|
||||
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
memory_data = self._load_memory_from_file(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
|
||||
|
||||
try:
|
||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
cache_key = (user_id, agent_name)
|
||||
with self._cache_lock:
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
"""Save memory data to file and update cache."""
|
||||
file_path = self._get_memory_file_path(agent_name)
|
||||
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
|
||||
|
||||
try:
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -165,8 +181,9 @@ class FileMemoryStorage(MemoryStorage):
|
||||
except OSError:
|
||||
mtime = None
|
||||
|
||||
cache_key = (user_id, agent_name)
|
||||
with self._cache_lock:
|
||||
self._memory_cache[agent_name] = (memory_data, mtime)
|
||||
self._memory_cache[cache_key] = (memory_data, mtime)
|
||||
logger.info("Memory saved to %s", file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
@@ -174,23 +191,31 @@ class FileMemoryStorage(MemoryStorage):
|
||||
return False
|
||||
|
||||
|
||||
_storage_instance: MemoryStorage | None = None
|
||||
# Instances keyed by (storage_class_path, id(memory_config)) so tests can
|
||||
# construct isolated storages and multi-client setups with different configs
|
||||
# don't collide on a single process-wide singleton.
|
||||
_storage_instances: dict[tuple[str, int], MemoryStorage] = {}
|
||||
_storage_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_memory_storage() -> MemoryStorage:
|
||||
"""Get the configured memory storage instance."""
|
||||
global _storage_instance
|
||||
if _storage_instance is not None:
|
||||
return _storage_instance
|
||||
def get_memory_storage(memory_config: MemoryConfig) -> MemoryStorage:
|
||||
"""Get the configured memory storage instance.
|
||||
|
||||
Caches one instance per ``(storage_class, memory_config)`` pair. In
|
||||
single-config deployments this collapses to one instance; in multi-client
|
||||
or test scenarios each config gets its own storage.
|
||||
"""
|
||||
key = (memory_config.storage_class, id(memory_config))
|
||||
existing = _storage_instances.get(key)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
with _storage_lock:
|
||||
if _storage_instance is not None:
|
||||
return _storage_instance
|
||||
|
||||
config = get_memory_config()
|
||||
storage_class_path = config.storage_class
|
||||
existing = _storage_instances.get(key)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
storage_class_path = memory_config.storage_class
|
||||
try:
|
||||
module_path, class_name = storage_class_path.rsplit(".", 1)
|
||||
import importlib
|
||||
@@ -204,13 +229,14 @@ def get_memory_storage() -> MemoryStorage:
|
||||
if not issubclass(storage_class, MemoryStorage):
|
||||
raise TypeError(f"Configured memory storage '{storage_class_path}' is not a subclass of MemoryStorage")
|
||||
|
||||
_storage_instance = storage_class()
|
||||
instance = storage_class(memory_config)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load memory storage %s, falling back to FileMemoryStorage: %s",
|
||||
storage_class_path,
|
||||
e,
|
||||
)
|
||||
_storage_instance = FileMemoryStorage()
|
||||
instance = FileMemoryStorage(memory_config)
|
||||
|
||||
return _storage_instance
|
||||
_storage_instances[key] = instance
|
||||
return instance
|
||||
|
||||
@@ -5,12 +5,19 @@ from __future__ import annotations
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
|
||||
def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||
"""Flush messages about to be summarized into the memory queue."""
|
||||
if not get_memory_config().enabled or not event.thread_id:
|
||||
"""Flush messages about to be summarized into the memory queue.
|
||||
|
||||
Reads ``AppConfig`` from disk on every invocation. This hook is fired by
|
||||
``SummarizationMiddleware`` which has no ergonomic way to thread an
|
||||
explicit ``app_config`` through; ``AppConfig.from_file()`` is a pure load
|
||||
so the cost is acceptable for this rare pre-summarization callback.
|
||||
"""
|
||||
app_config = AppConfig.from_file()
|
||||
if not app_config.memory.enabled or not event.thread_id:
|
||||
return
|
||||
|
||||
filtered_messages = filter_messages_for_memory(list(event.messages_to_summarize))
|
||||
@@ -21,7 +28,7 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
|
||||
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
queue = get_memory_queue(app_config)
|
||||
queue.add_nowait(
|
||||
thread_id=event.thread_id,
|
||||
messages=filtered_messages,
|
||||
|
||||
@@ -21,7 +21,8 @@ from deerflow.agents.memory.storage import (
|
||||
get_memory_storage,
|
||||
utc_now_iso_z,
|
||||
)
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -38,44 +39,33 @@ def _create_empty_memory() -> dict[str, Any]:
|
||||
return create_empty_memory()
|
||||
|
||||
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Backward-compatible wrapper around the configured memory storage save path."""
|
||||
return get_memory_storage().save(memory_data, agent_name)
|
||||
def _save_memory_to_file(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
|
||||
"""Save via the configured memory storage."""
|
||||
return get_memory_storage(memory_config).save(memory_data, agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
def get_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data via storage provider."""
|
||||
return get_memory_storage().load(agent_name)
|
||||
return get_memory_storage(memory_config).load(agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
def reload_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data via storage provider."""
|
||||
return get_memory_storage().reload(agent_name)
|
||||
return get_memory_storage(memory_config).reload(agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Persist imported memory data via storage provider.
|
||||
|
||||
Args:
|
||||
memory_data: Full memory payload to persist.
|
||||
agent_name: If provided, imports into per-agent memory.
|
||||
|
||||
Returns:
|
||||
The saved memory data after storage normalization.
|
||||
|
||||
Raises:
|
||||
OSError: If persisting the imported memory fails.
|
||||
"""
|
||||
storage = get_memory_storage()
|
||||
if not storage.save(memory_data, agent_name):
|
||||
def import_memory_data(memory_config: MemoryConfig, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Persist imported memory data via storage provider."""
|
||||
storage = get_memory_storage(memory_config)
|
||||
if not storage.save(memory_data, agent_name, user_id=user_id):
|
||||
raise OSError("Failed to save imported memory data")
|
||||
return storage.load(agent_name)
|
||||
return storage.load(agent_name, user_id=user_id)
|
||||
|
||||
|
||||
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
def clear_memory_data(memory_config: MemoryConfig, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Clear all stored memory data and persist an empty structure."""
|
||||
cleared_memory = create_empty_memory()
|
||||
if not _save_memory_to_file(cleared_memory, agent_name):
|
||||
if not _save_memory_to_file(memory_config, cleared_memory, agent_name, user_id=user_id):
|
||||
raise OSError("Failed to save cleared memory data")
|
||||
return cleared_memory
|
||||
|
||||
@@ -88,10 +78,13 @@ def _validate_confidence(confidence: float) -> float:
|
||||
|
||||
|
||||
def create_memory_fact(
|
||||
memory_config: MemoryConfig,
|
||||
content: str,
|
||||
category: str = "context",
|
||||
confidence: float = 0.5,
|
||||
agent_name: str | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a new fact and persist the updated memory data."""
|
||||
normalized_content = content.strip()
|
||||
@@ -101,7 +94,7 @@ def create_memory_fact(
|
||||
normalized_category = category.strip() or "context"
|
||||
validated_confidence = _validate_confidence(confidence)
|
||||
now = utc_now_iso_z()
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
|
||||
updated_memory = dict(memory_data)
|
||||
facts = list(memory_data.get("facts", []))
|
||||
facts.append(
|
||||
@@ -116,15 +109,15 @@ def create_memory_fact(
|
||||
)
|
||||
updated_memory["facts"] = facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError("Failed to save memory data after creating fact")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
|
||||
def delete_memory_fact(memory_config: MemoryConfig, fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
|
||||
"""Delete a fact by its id and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
|
||||
facts = memory_data.get("facts", [])
|
||||
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
|
||||
if len(updated_facts) == len(facts):
|
||||
@@ -133,21 +126,24 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
|
||||
updated_memory = dict(memory_data)
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
|
||||
|
||||
def update_memory_fact(
|
||||
memory_config: MemoryConfig,
|
||||
fact_id: str,
|
||||
content: str | None = None,
|
||||
category: str | None = None,
|
||||
confidence: float | None = None,
|
||||
agent_name: str | None = None,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Update an existing fact and persist the updated memory data."""
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_data = get_memory_data(memory_config, agent_name, user_id=user_id)
|
||||
updated_memory = dict(memory_data)
|
||||
updated_facts: list[dict[str, Any]] = []
|
||||
found = False
|
||||
@@ -174,7 +170,7 @@ def update_memory_fact(
|
||||
|
||||
updated_memory["facts"] = updated_facts
|
||||
|
||||
if not _save_memory_to_file(updated_memory, agent_name):
|
||||
if not _save_memory_to_file(memory_config, updated_memory, agent_name, user_id=user_id):
|
||||
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
|
||||
|
||||
return updated_memory
|
||||
@@ -299,19 +295,25 @@ def _fact_content_key(content: Any) -> str | None:
|
||||
class MemoryUpdater:
|
||||
"""Updates memory using LLM based on conversation context."""
|
||||
|
||||
def __init__(self, model_name: str | None = None):
|
||||
def __init__(self, app_config: AppConfig, model_name: str | None = None):
|
||||
"""Initialize the memory updater.
|
||||
|
||||
Args:
|
||||
app_config: Application config (the updater needs both ``memory``
|
||||
section for behavior and the full config for ``create_chat_model``).
|
||||
model_name: Optional model name to use. If None, uses config or default.
|
||||
"""
|
||||
self._app_config = app_config
|
||||
self._model_name = model_name
|
||||
|
||||
@property
|
||||
def _memory_config(self) -> MemoryConfig:
|
||||
return self._app_config.memory
|
||||
|
||||
def _get_model(self):
|
||||
"""Get the model for memory updates."""
|
||||
config = get_memory_config()
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
model_name = self._model_name or self._memory_config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False, app_config=self._app_config)
|
||||
|
||||
def _build_correction_hint(
|
||||
self,
|
||||
@@ -344,13 +346,14 @@ class MemoryUpdater:
|
||||
agent_name: str | None,
|
||||
correction_detected: bool,
|
||||
reinforcement_detected: bool,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[dict[str, Any], str] | None:
|
||||
"""Load memory and build the update prompt for a conversation."""
|
||||
config = get_memory_config()
|
||||
config = self._memory_config
|
||||
if not config.enabled or not messages:
|
||||
return None
|
||||
|
||||
current_memory = get_memory_data(agent_name)
|
||||
current_memory = get_memory_data(config, agent_name, user_id=user_id)
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
if not conversation_text.strip():
|
||||
return None
|
||||
@@ -372,6 +375,7 @@ class MemoryUpdater:
|
||||
response_content: Any,
|
||||
thread_id: str | None,
|
||||
agent_name: str | None,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Parse the model response, apply updates, and persist memory."""
|
||||
response_text = _extract_text(response_content).strip()
|
||||
@@ -385,7 +389,7 @@ class MemoryUpdater:
|
||||
# cannot corrupt the still-cached original object reference.
|
||||
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
return get_memory_storage().save(updated_memory, agent_name)
|
||||
return get_memory_storage(self._memory_config).save(updated_memory, agent_name, user_id=user_id)
|
||||
|
||||
async def aupdate_memory(
|
||||
self,
|
||||
@@ -394,6 +398,7 @@ class MemoryUpdater:
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Update memory asynchronously based on conversation messages."""
|
||||
try:
|
||||
@@ -403,6 +408,7 @@ class MemoryUpdater:
|
||||
agent_name=agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
user_id=user_id,
|
||||
)
|
||||
if prepared is None:
|
||||
return False
|
||||
@@ -416,6 +422,7 @@ class MemoryUpdater:
|
||||
response_content=response.content,
|
||||
thread_id=thread_id,
|
||||
agent_name=agent_name,
|
||||
user_id=user_id,
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||
@@ -431,6 +438,7 @@ class MemoryUpdater:
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Synchronously update memory via the async updater path.
|
||||
|
||||
@@ -440,19 +448,83 @@ class MemoryUpdater:
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
"""
|
||||
return _run_async_update_sync(
|
||||
self.aupdate_memory(
|
||||
messages=messages,
|
||||
thread_id=thread_id,
|
||||
agent_name=agent_name,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
config = self._memory_config
|
||||
if not config.enabled:
|
||||
return False
|
||||
|
||||
if not messages:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get current memory
|
||||
current_memory = get_memory_data(config, agent_name, user_id=user_id)
|
||||
|
||||
# Format conversation for prompt
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
|
||||
if not conversation_text.strip():
|
||||
return False
|
||||
|
||||
# Build prompt
|
||||
correction_hint = ""
|
||||
if correction_detected:
|
||||
correction_hint = (
|
||||
"IMPORTANT: Explicit correction signals were detected in this conversation. "
|
||||
"Pay special attention to what the agent got wrong, what the user corrected, "
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
conversation=conversation_text,
|
||||
correction_hint=correction_hint,
|
||||
)
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
model = self._get_model()
|
||||
response = model.invoke(prompt)
|
||||
response_text = _extract_text(response.content).strip()
|
||||
|
||||
# Parse response
|
||||
# Remove markdown code blocks if present
|
||||
if response_text.startswith("```"):
|
||||
lines = response_text.split("\n")
|
||||
response_text = "\n".join(lines[1:-1] if lines[-1] == "```" else lines[1:])
|
||||
|
||||
update_data = json.loads(response_text)
|
||||
|
||||
# Apply updates
|
||||
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||
|
||||
# Strip file-upload mentions from all summaries before saving.
|
||||
# Uploaded files are session-scoped and won't exist in future sessions,
|
||||
# so recording upload events in long-term memory causes the agent to
|
||||
# try (and fail) to locate those files in subsequent conversations.
|
||||
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
|
||||
|
||||
# Save
|
||||
return get_memory_storage(config).save(updated_memory, agent_name, user_id=user_id)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse LLM response for memory update: %s", e)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception("Memory update failed: %s", e)
|
||||
return False
|
||||
|
||||
def _apply_updates(
|
||||
self,
|
||||
@@ -470,7 +542,7 @@ class MemoryUpdater:
|
||||
Returns:
|
||||
Updated memory data.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
config = self._memory_config
|
||||
now = utc_now_iso_z()
|
||||
|
||||
# Update user sections
|
||||
@@ -547,6 +619,7 @@ def update_memory_from_conversation(
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
@@ -556,9 +629,10 @@ def update_memory_from_conversation(
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
user_id: If provided, scopes memory to a specific user.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)
|
||||
|
||||
+2
-2
@@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.errors import GraphBubbleUp
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -78,7 +78,7 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
# Load Circuit Breaker configs from app config if available, fall back to defaults
|
||||
try:
|
||||
app_config = get_app_config()
|
||||
app_config = AppConfig.from_file()
|
||||
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
|
||||
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
|
||||
except (FileNotFoundError, RuntimeError):
|
||||
|
||||
@@ -25,6 +25,8 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Defaults — can be overridden via constructor
|
||||
@@ -181,12 +183,9 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int))
|
||||
self._tool_freq_warned: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
def _get_thread_id(self, runtime: Runtime) -> str:
|
||||
def _get_thread_id(self, runtime: Runtime[DeerFlowContext]) -> str:
|
||||
"""Extract thread_id from runtime context for per-thread tracking."""
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id:
|
||||
return thread_id
|
||||
return "default"
|
||||
return runtime.context.thread_id or "default"
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evict least recently used threads if over the limit.
|
||||
@@ -367,11 +366,11 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
||||
return None
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
def after_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return self._apply(state, runtime)
|
||||
|
||||
def reset(self, thread_id: str | None = None) -> None:
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
|
||||
from deerflow.agents.memory.queue import get_memory_queue
|
||||
from deerflow.config.memory_config import get_memory_config
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +43,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
self._agent_name = agent_name
|
||||
|
||||
@override
|
||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
"""Queue conversation for memory update after agent completes.
|
||||
|
||||
Args:
|
||||
@@ -53,15 +53,11 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
Returns:
|
||||
None (no state changes needed from this middleware).
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
memory_config = runtime.context.app_config.memory
|
||||
if not memory_config.enabled:
|
||||
return None
|
||||
|
||||
# Get thread ID from runtime context first, then fall back to LangGraph's configurable metadata
|
||||
thread_id = runtime.context.get("thread_id") if runtime.context else None
|
||||
if thread_id is None:
|
||||
config_data = get_config()
|
||||
thread_id = config_data.get("configurable", {}).get("thread_id")
|
||||
thread_id = runtime.context.thread_id
|
||||
if not thread_id:
|
||||
logger.debug("No thread_id in context, skipping memory update")
|
||||
return None
|
||||
@@ -86,11 +82,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
# Queue the filtered conversation for memory update
|
||||
correction_detected = detect_correction(filtered_messages)
|
||||
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
|
||||
queue = get_memory_queue()
|
||||
# Capture user_id at enqueue time while the request context is still alive.
|
||||
# threading.Timer fires on a different thread where ContextVar values are not
|
||||
# propagated, so we must store user_id explicitly in ConversationContext.
|
||||
user_id = get_effective_user_id()
|
||||
queue = get_memory_queue(runtime.context.app_config)
|
||||
queue.add(
|
||||
thread_id=thread_id,
|
||||
messages=filtered_messages,
|
||||
agent_name=self._agent_name,
|
||||
user_id=user_id,
|
||||
correction_detected=correction_detected,
|
||||
reinforcement_detected=reinforcement_detected,
|
||||
)
|
||||
|
||||
@@ -3,11 +3,12 @@ from typing import NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,50 +47,50 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
|
||||
self._paths = Paths(base_dir) if base_dir else get_paths()
|
||||
self._lazy_init = lazy_init
|
||||
|
||||
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
|
||||
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||
"""Get the paths for a thread's data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
user_id: Optional user ID for per-user path isolation.
|
||||
|
||||
Returns:
|
||||
Dictionary with workspace_path, uploads_path, and outputs_path.
|
||||
"""
|
||||
return {
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
|
||||
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
|
||||
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
|
||||
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
|
||||
}
|
||||
|
||||
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
|
||||
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
|
||||
"""Create the thread data directories.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
user_id: Optional user ID for per-user path isolation.
|
||||
|
||||
Returns:
|
||||
Dictionary with the created directory paths.
|
||||
"""
|
||||
self._paths.ensure_thread_dirs(thread_id)
|
||||
return self._get_thread_paths(thread_id)
|
||||
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
|
||||
return self._get_thread_paths(thread_id, user_id=user_id)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
context = runtime.context or {}
|
||||
thread_id = context.get("thread_id")
|
||||
if thread_id is None:
|
||||
config = get_config()
|
||||
thread_id = config.get("configurable", {}).get("thread_id")
|
||||
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
thread_id = runtime.context.thread_id
|
||||
|
||||
if thread_id is None:
|
||||
if not thread_id:
|
||||
raise ValueError("Thread ID is required in runtime context or config.configurable")
|
||||
|
||||
user_id = get_effective_user_id()
|
||||
|
||||
if self._lazy_init:
|
||||
# Lazy initialization: only compute paths, don't create directories
|
||||
paths = self._get_thread_paths(thread_id)
|
||||
paths = self._get_thread_paths(thread_id, user_id=user_id)
|
||||
else:
|
||||
# Eager initialization: create directories immediately
|
||||
paths = self._create_thread_directories(thread_id)
|
||||
paths = self._create_thread_directories(thread_id, user_id=user_id)
|
||||
logger.debug("Created thread data directories for thread %s", thread_id)
|
||||
|
||||
return {
|
||||
|
||||
@@ -2,13 +2,16 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import NotRequired, override
|
||||
from typing import Any, NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.title_config import get_title_config
|
||||
from deerflow.config.app_config import AppConfig
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.title_config import TitleConfig
|
||||
from deerflow.models import create_chat_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,10 +47,9 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
|
||||
return ""
|
||||
|
||||
def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
|
||||
def _should_generate_title(self, state: TitleMiddlewareState, title_config: TitleConfig) -> bool:
|
||||
"""Check if we should generate a title for this thread."""
|
||||
config = get_title_config()
|
||||
if not config.enabled:
|
||||
if not title_config.enabled:
|
||||
return False
|
||||
|
||||
# Check if thread already has a title in state
|
||||
@@ -66,12 +68,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
# Generate title after first complete exchange
|
||||
return len(user_messages) == 1 and len(assistant_messages) >= 1
|
||||
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState) -> tuple[str, str]:
|
||||
def _build_title_prompt(self, state: TitleMiddlewareState, title_config: TitleConfig) -> tuple[str, str]:
|
||||
"""Extract user/assistant messages and build the title prompt.
|
||||
|
||||
Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
|
||||
"""
|
||||
config = get_title_config()
|
||||
messages = state.get("messages", [])
|
||||
|
||||
user_msg_content = next((m.content for m in messages if m.type == "human"), "")
|
||||
@@ -80,8 +81,8 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
user_msg = self._normalize_content(user_msg_content)
|
||||
assistant_msg = self._strip_think_tags(self._normalize_content(assistant_msg_content))
|
||||
|
||||
prompt = config.prompt_template.format(
|
||||
max_words=config.max_words,
|
||||
prompt = title_config.prompt_template.format(
|
||||
max_words=title_config.max_words,
|
||||
user_msg=user_msg[:500],
|
||||
assistant_msg=assistant_msg[:500],
|
||||
)
|
||||
@@ -91,54 +92,66 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
"""Remove <think>...</think> blocks emitted by reasoning models (e.g. minimax, DeepSeek-R1)."""
|
||||
return re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
|
||||
|
||||
def _parse_title(self, content: object) -> str:
|
||||
def _parse_title(self, content: object, title_config: TitleConfig) -> str:
|
||||
"""Normalize model output into a clean title string."""
|
||||
config = get_title_config()
|
||||
title_content = self._normalize_content(content)
|
||||
title_content = self._strip_think_tags(title_content)
|
||||
title = title_content.strip().strip('"').strip("'")
|
||||
return title[: config.max_chars] if len(title) > config.max_chars else title
|
||||
return title[: title_config.max_chars] if len(title) > title_config.max_chars else title
|
||||
|
||||
def _fallback_title(self, user_msg: str) -> str:
|
||||
config = get_title_config()
|
||||
fallback_chars = min(config.max_chars, 50)
|
||||
def _fallback_title(self, user_msg: str, title_config: TitleConfig) -> str:
|
||||
fallback_chars = min(title_config.max_chars, 50)
|
||||
if len(user_msg) > fallback_chars:
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
return user_msg if user_msg else "New Conversation"
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
def _get_runnable_config(self) -> dict[str, Any]:
|
||||
"""Inherit the parent RunnableConfig and add middleware tag.
|
||||
|
||||
This ensures RunJournal identifies LLM calls from this middleware
|
||||
as ``middleware:title`` instead of ``lead_agent``.
|
||||
"""
|
||||
try:
|
||||
parent = get_config()
|
||||
except Exception:
|
||||
parent = {}
|
||||
config = {**parent}
|
||||
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
||||
return config
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState, title_config: TitleConfig) -> dict | None:
|
||||
"""Generate a local fallback title without blocking on an LLM call."""
|
||||
if not self._should_generate_title(state):
|
||||
if not self._should_generate_title(state, title_config):
|
||||
return None
|
||||
|
||||
_, user_msg = self._build_title_prompt(state)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
_, user_msg = self._build_title_prompt(state, title_config)
|
||||
return {"title": self._fallback_title(user_msg, title_config)}
|
||||
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
async def _agenerate_title_result(self, state: TitleMiddlewareState, app_config: AppConfig) -> dict | None:
|
||||
"""Generate a title asynchronously and fall back locally on failure."""
|
||||
if not self._should_generate_title(state):
|
||||
title_config = app_config.title
|
||||
if not self._should_generate_title(state, title_config):
|
||||
return None
|
||||
|
||||
config = get_title_config()
|
||||
prompt, user_msg = self._build_title_prompt(state)
|
||||
prompt, user_msg = self._build_title_prompt(state, title_config)
|
||||
|
||||
try:
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
if title_config.model_name:
|
||||
model = create_chat_model(name=title_config.model_name, thinking_enabled=False, app_config=app_config)
|
||||
else:
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
response = await model.ainvoke(prompt, config={"run_name": "title_agent"})
|
||||
title = self._parse_title(response.content)
|
||||
model = create_chat_model(thinking_enabled=False, app_config=app_config)
|
||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content, title_config)
|
||||
if title:
|
||||
return {"title": title}
|
||||
except Exception:
|
||||
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
|
||||
return {"title": self._fallback_title(user_msg)}
|
||||
return {"title": self._fallback_title(user_msg, title_config)}
|
||||
|
||||
@override
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return self._generate_title_result(state)
|
||||
def after_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return self._generate_title_result(state, runtime.context.app_config.title)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
return await self._agenerate_title_result(state)
|
||||
async def aafter_model(self, state: TitleMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
return await self._agenerate_title_result(state, runtime.context.app_config)
|
||||
|
||||
+10
-5
@@ -1,8 +1,10 @@
|
||||
"""Tool error handling middleware and shared runtime middleware builders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import override
|
||||
from typing import TYPE_CHECKING, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
@@ -11,6 +13,9 @@ from langgraph.errors import GraphBubbleUp
|
||||
from langgraph.prebuilt.tool_node import ToolCallRequest
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.config.app_config import AppConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
|
||||
@@ -67,6 +72,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
|
||||
|
||||
def _build_runtime_middlewares(
|
||||
*,
|
||||
app_config: "AppConfig",
|
||||
include_uploads: bool,
|
||||
include_dangling_tool_call_patch: bool,
|
||||
lazy_init: bool = True,
|
||||
@@ -94,9 +100,7 @@ def _build_runtime_middlewares(
|
||||
middlewares.append(LLMErrorHandlingMiddleware())
|
||||
|
||||
# Guardrail middleware (if configured)
|
||||
from deerflow.config.guardrails_config import get_guardrails_config
|
||||
|
||||
guardrails_config = get_guardrails_config()
|
||||
guardrails_config = app_config.guardrails
|
||||
if guardrails_config.enabled and guardrails_config.provider:
|
||||
import inspect
|
||||
|
||||
@@ -125,9 +129,10 @@ def _build_runtime_middlewares(
|
||||
return middlewares
|
||||
|
||||
|
||||
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
def build_lead_runtime_middlewares(*, app_config: "AppConfig", lazy_init: bool = True) -> list[AgentMiddleware]:
|
||||
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
|
||||
return _build_runtime_middlewares(
|
||||
app_config=app_config,
|
||||
include_uploads=True,
|
||||
include_dangling_tool_call_patch=True,
|
||||
lazy_init=lazy_init,
|
||||
|
||||
@@ -9,7 +9,9 @@ from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.deer_flow_context import DeerFlowContext
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
from deerflow.utils.file_conversion import extract_outline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -184,7 +186,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
return files if files else None
|
||||
|
||||
@override
|
||||
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
def before_agent(self, state: UploadsMiddlewareState, runtime: Runtime[DeerFlowContext]) -> dict | None:
|
||||
"""Inject uploaded files information before agent execution.
|
||||
|
||||
New files come from the current message's additional_kwargs.files.
|
||||
@@ -213,15 +215,8 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
|
||||
return None
|
||||
|
||||
# Resolve uploads directory for existence checks
|
||||
thread_id = (runtime.context or {}).get("thread_id")
|
||||
if thread_id is None:
|
||||
try:
|
||||
from langgraph.config import get_config
|
||||
|
||||
thread_id = get_config().get("configurable", {}).get("thread_id")
|
||||
except RuntimeError:
|
||||
pass # get_config() raises outside a runnable context (e.g. unit tests)
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
|
||||
thread_id = runtime.context.thread_id
|
||||
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
|
||||
|
||||
# Get newly uploaded files from the current message's additional_kwargs.files
|
||||
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
|
||||
|
||||
Reference in New Issue
Block a user