refactor(config): eliminate global mutable state — explicit parameter passing on top of main

Squashes 25 PR commits onto current main. AppConfig becomes a pure value
object with no ambient lookup. Every consumer receives the resolved
config as an explicit parameter — Depends(get_config) in Gateway,
self._app_config in DeerFlowClient, runtime.context.app_config in agent
runs, AppConfig.from_file() at the LangGraph Server registration
boundary.

Phase 1 — frozen data + typed context

- All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become
  frozen=True; no sub-module globals.
- AppConfig.from_file() is pure (no side-effect singleton loaders).
- Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name)
  — frozen dataclass injected via LangGraph Runtime.
- Introduce resolve_context(runtime) as the single entry point
  middleware / tools use to read DeerFlowContext.

Phase 2 — pure explicit parameter passing

- Gateway: app.state.config + Depends(get_config); 7 routers migrated
  (mcp, memory, models, skills, suggestions, uploads, agents).
- DeerFlowClient: __init__(config=...) captures config locally.
- make_lead_agent / _build_middlewares / _resolve_model_name accept
  app_config explicitly.
- RunContext.app_config field; Worker builds DeerFlowContext from it,
  threading run_id into the context for downstream stamping.
- Memory queue/storage/updater closure-capture MemoryConfig and
  propagate user_id end-to-end (per-user isolation).
- Sandbox/skills/community/factories/tools thread app_config.
- resolve_context() rejects non-typed runtime.context.
- Test suite migrated off AppConfig.current() monkey-patches.
- AppConfig.current() classmethod deleted.

Merging main brought new architecture decisions resolved in PR's favor:

- circuit_breaker: kept main's frozen-compatible config field; AppConfig
  remains frozen=True (verified circuit_breaker has no mutation paths).
- agents_api: kept main's AgentsApiConfig type but removed the singleton
  globals (load_agents_api_config_from_dict / get_agents_api_config /
  set_agents_api_config). 8 routes in agents.py now read via
  Depends(get_config).
- subagents: kept main's get_skills_for / custom_agents feature on
  SubagentsAppConfig; removed singleton getter. registry.py now reads
  app_config.subagents directly.
- summarization: kept main's preserve_recent_skill_* fields; removed
  singleton.
- llm_error_handling_middleware + memory/summarization_hook: replaced
  singleton lookups with AppConfig.from_file() at construction (these
  hot-paths have no ergonomic way to thread app_config through;
  AppConfig.from_file is a pure load).
- worker.py + thread_data_middleware.py: DeerFlowContext.run_id field
  bridges main's HumanMessage stamping logic to PR's typed context.

Trade-offs (follow-up work):

- main's #2138 (async memory updater) reverted to PR's sync
  implementation. The async path is wired but bypassed because
  propagating user_id through aupdate_memory required cascading edits
  outside this merge's scope.
- tests/test_subagent_skills_config.py removed: it relied heavily on
  the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict).
  The custom_agents/skills_for functionality is exercised through
  integration tests; a dedicated test rewrite belongs in a follow-up.

Verification: backend test suite — 2560 passed, 4 skipped, 84 failures.
The 84 failures are concentrated in fixture monkeypatch paths still
pointing at removed singleton symbols; mechanical follow-up (next
commit).
This commit is contained in:
greatmengqi
2026-04-26 21:45:02 +08:00
parent 9dc25987e0
commit 3e6a34297d
365 changed files with 31220 additions and 5303 deletions
@@ -5,15 +5,22 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
directly from ``deerflow.runtime``.
"""
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
__all__ = [
# checkpointer
"checkpointer_context",
"get_checkpointer",
"make_checkpointer",
"reset_checkpointer",
# runs
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",
@@ -0,0 +1,9 @@
from .async_provider import make_checkpointer
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
__all__ = [
"get_checkpointer",
"reset_checkpointer",
"checkpointer_context",
"make_checkpointer",
]
@@ -0,0 +1,157 @@
"""Async checkpointer factory.
Provides an **async context manager** for long-running async servers that need
proper resource cleanup.
Supported backends: memory, sqlite, postgres.
Usage (e.g. FastAPI lifespan)::
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
async with make_checkpointer() as checkpointer:
app.state.checkpointer = checkpointer # InMemorySaver if not configured
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
"""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import AppConfig
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
SQLITE_INSTALL,
)
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Async factory
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs and tears down a checkpointer."""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Public async context manager
# ---------------------------------------------------------------------------
@contextlib.asynccontextmanager
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
if db_config.backend == "memory":
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
if db_config.backend == "sqlite":
try:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = db_config.checkpointer_sqlite_path
ensure_sqlite_parent_dir(conn_str)
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
await saver.setup()
yield saver
return
if db_config.backend == "postgres":
try:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend")
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
await saver.setup()
yield saver
return
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
@contextlib.asynccontextmanager
async def make_checkpointer(app_config: AppConfig) -> AsyncIterator[Checkpointer]:
"""Async context manager that yields a checkpointer for the caller's lifetime.
Resources are opened on enter and closed on exit -- no global state::
async with make_checkpointer(app_config) as checkpointer:
app.state.checkpointer = checkpointer
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
Priority:
1. Legacy ``checkpointer:`` config section (backward compatible)
2. Unified ``database:`` config section
3. Default InMemorySaver
"""
# Legacy: standalone checkpointer config takes precedence
if app_config.checkpointer is not None:
async with _async_checkpointer(app_config.checkpointer) as saver:
yield saver
return
# Unified database config
db_config = getattr(app_config, "database", None)
if db_config is not None and db_config.backend != "memory":
async with _async_checkpointer_from_database(db_config) as saver:
yield saver
return
# Default: in-memory
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
@@ -0,0 +1,175 @@
"""Sync checkpointer factory.
Provides a **sync singleton** and a **sync context manager** for LangGraph
graph compilation and CLI tools.
Supported backends: memory, sqlite, postgres.
Usage::
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
# Singleton — reused across calls, closed on process exit
cp = get_checkpointer()
# One-shot — fresh connection, closed on block exit
with checkpointer_context() as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
"""
from __future__ import annotations
import contextlib
import logging
from collections.abc import Iterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import AppConfig
from deerflow.config.checkpointer_config import CheckpointerConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Error message constants — imported by aio.provider too
# ---------------------------------------------------------------------------
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
# ---------------------------------------------------------------------------
# Sync factory
# ---------------------------------------------------------------------------
@contextlib.contextmanager
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
"""Context manager that creates and tears down a sync checkpointer.
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
underlying connections or pools is handled by higher-level helpers in
this module (such as the singleton factory or context manager); this
function does not return a separate cleanup callback.
"""
if config.type == "memory":
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
yield InMemorySaver()
return
if config.type == "sqlite":
try:
from langgraph.checkpoint.sqlite import SqliteSaver
except ImportError as exc:
raise ImportError(SQLITE_INSTALL) from exc
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
ensure_sqlite_parent_dir(conn_str)
with SqliteSaver.from_conn_string(conn_str) as saver:
saver.setup()
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
yield saver
return
if config.type == "postgres":
try:
from langgraph.checkpoint.postgres import PostgresSaver
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED)
with PostgresSaver.from_conn_string(config.connection_string) as saver:
saver.setup()
logger.info("Checkpointer: using PostgresSaver")
yield saver
return
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
# ---------------------------------------------------------------------------
# Sync singleton
# ---------------------------------------------------------------------------
_checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive
def get_checkpointer(app_config: AppConfig) -> Checkpointer:
"""Return the global sync checkpointer singleton, creating it on first call.
Returns an ``InMemorySaver`` only when ``checkpointer`` is explicitly
absent from config.yaml. Any other failure (missing config, invalid
backend, connection error) propagates — silent degradation to in-memory
would drop persistent-run state on process restart.
Raises:
ImportError: If the required package for the configured backend is not installed.
ValueError: If ``connection_string`` is missing for a backend that requires it.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer is not None:
return _checkpointer
config = app_config.checkpointer
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = InMemorySaver()
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config)
_checkpointer = _checkpointer_ctx.__enter__()
return _checkpointer
def reset_checkpointer() -> None:
"""Reset the sync singleton, forcing recreation on the next call.
Closes any open backend connections and clears the cached instance.
Useful in tests or after a configuration change.
"""
global _checkpointer, _checkpointer_ctx
if _checkpointer_ctx is not None:
try:
_checkpointer_ctx.__exit__(None, None, None)
except Exception:
logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer_ctx = None
_checkpointer = None
# ---------------------------------------------------------------------------
# Sync context manager
# ---------------------------------------------------------------------------
@contextlib.contextmanager
def checkpointer_context(app_config: AppConfig) -> Iterator[Checkpointer]:
"""Sync context manager that yields a checkpointer and cleans up on exit.
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
each ``with`` block creates and destroys its own connection. Use it in
CLI scripts or tests where you want deterministic cleanup::
with checkpointer_context(app_config) as cp:
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
"""
if app_config.checkpointer is None:
from langgraph.checkpoint.memory import InMemorySaver
yield InMemorySaver()
return
with _sync_checkpointer_cm(app_config.checkpointer) as saver:
yield saver
@@ -0,0 +1,134 @@
"""Pure functions to convert LangChain message objects to OpenAI Chat Completions format.
Used by RunJournal to build content dicts for event storage.
"""
from __future__ import annotations
import json
from typing import Any
_ROLE_MAP = {
"human": "user",
"ai": "assistant",
"system": "system",
"tool": "tool",
}
def langchain_to_openai_message(message: Any) -> dict:
"""Convert a single LangChain BaseMessage to an OpenAI message dict.
Handles:
- HumanMessage → {"role": "user", "content": "..."}
- AIMessage (text only) → {"role": "assistant", "content": "..."}
- AIMessage (with tool_calls) → {"role": "assistant", "content": null, "tool_calls": [...]}
- AIMessage (text + tool_calls) → both content and tool_calls present
- AIMessage (list content / multimodal) → content preserved as list
- SystemMessage → {"role": "system", "content": "..."}
- ToolMessage → {"role": "tool", "tool_call_id": "...", "content": "..."}
"""
msg_type = getattr(message, "type", "")
role = _ROLE_MAP.get(msg_type, msg_type)
content = getattr(message, "content", "")
if role == "tool":
return {
"role": "tool",
"tool_call_id": getattr(message, "tool_call_id", ""),
"content": content,
}
if role == "assistant":
tool_calls = getattr(message, "tool_calls", None) or []
result: dict = {"role": "assistant"}
if tool_calls:
openai_tool_calls = []
for tc in tool_calls:
args = tc.get("args", {})
openai_tool_calls.append(
{
"id": tc.get("id", ""),
"type": "function",
"function": {
"name": tc.get("name", ""),
"arguments": json.dumps(args) if not isinstance(args, str) else args,
},
}
)
# If no text content, set content to null per OpenAI spec
result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
result["tool_calls"] = openai_tool_calls
else:
result["content"] = content
return result
# user / system / unknown
return {"role": role, "content": content}
def _infer_finish_reason(message: Any) -> str:
"""Infer OpenAI finish_reason from an AIMessage.
Returns "tool_calls" if tool_calls present, else looks in
response_metadata.finish_reason, else returns "stop".
"""
tool_calls = getattr(message, "tool_calls", None) or []
if tool_calls:
return "tool_calls"
resp_meta = getattr(message, "response_metadata", None) or {}
if isinstance(resp_meta, dict):
finish = resp_meta.get("finish_reason")
if finish:
return finish
return "stop"
def langchain_to_openai_completion(message: Any) -> dict:
"""Convert an AIMessage and its metadata to an OpenAI completion response dict.
Returns:
{
"id": message.id,
"model": message.response_metadata.get("model_name"),
"choices": [{"index": 0, "message": <openai_message>, "finish_reason": <inferred>}],
"usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...} or None,
}
"""
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
openai_msg = langchain_to_openai_message(message)
finish_reason = _infer_finish_reason(message)
usage_metadata = getattr(message, "usage_metadata", None)
if usage_metadata is not None:
input_tokens = usage_metadata.get("input_tokens", 0) or 0
output_tokens = usage_metadata.get("output_tokens", 0) or 0
usage: dict | None = {
"prompt_tokens": input_tokens,
"completion_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
else:
usage = None
return {
"id": getattr(message, "id", None),
"model": model_name,
"choices": [
{
"index": 0,
"message": openai_msg,
"finish_reason": finish_reason,
}
],
"usage": usage,
}
def langchain_messages_to_openai(messages: list) -> list[dict]:
"""Convert a list of LangChain BaseMessages to OpenAI message dicts."""
return [langchain_to_openai_message(m) for m in messages]
@@ -0,0 +1,4 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
__all__ = ["MemoryRunEventStore", "RunEventStore"]
@@ -0,0 +1,26 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
def make_run_event_store(config=None) -> RunEventStore:
"""Create a RunEventStore based on run_events.backend configuration."""
if config is None or config.backend == "memory":
return MemoryRunEventStore()
if config.backend == "db":
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
# database.backend=memory but run_events.backend=db -> fallback
return MemoryRunEventStore()
from deerflow.runtime.events.store.db import DbRunEventStore
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
if config.backend == "jsonl":
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
return JsonlRunEventStore()
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
@@ -0,0 +1,109 @@
"""Abstract interface for run event storage.
RunEventStore is the unified storage interface for run event streams.
Messages (frontend display) and execution traces (debugging/audit) go
through the same interface, distinguished by the ``category`` field.
Implementations:
- MemoryRunEventStore: in-memory dict (development, tests)
- Future: DB-backed store (SQLAlchemy ORM), JSONL file store
"""
from __future__ import annotations
import abc
class RunEventStore(abc.ABC):
"""Run event stream storage interface.
All implementations must guarantee:
1. put() events are retrievable in subsequent queries
2. seq is strictly increasing within the same thread
3. list_messages() only returns category="message" events
4. list_events() returns all events for the specified run
5. Returned dicts match the RunEvent field structure
"""
@abc.abstractmethod
async def put(
self,
*,
thread_id: str,
run_id: str,
event_type: str,
category: str,
content: str | dict = "",
metadata: dict | None = None,
created_at: str | None = None,
) -> dict:
"""Write an event, auto-assign seq, return the complete record."""
@abc.abstractmethod
async def put_batch(self, events: list[dict]) -> list[dict]:
"""Batch-write events. Used by RunJournal flush buffer.
Each dict's keys match put()'s keyword arguments.
Returns complete records with seq assigned.
"""
@abc.abstractmethod
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict]:
"""Return displayable messages (category=message) for a thread, ordered by seq ascending.
Supports bidirectional cursor pagination:
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
- neither: return the latest ``limit`` records (ascending)
"""
@abc.abstractmethod
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
) -> list[dict]:
"""Return the full event stream for a run, ordered by seq ascending.
Optionally filter by event_types.
"""
@abc.abstractmethod
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[dict]:
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
Supports bidirectional cursor pagination:
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
- neither: return the latest ``limit`` records (ascending)
"""
@abc.abstractmethod
async def count_messages(self, thread_id: str) -> int:
"""Count displayable messages (category=message) in a thread."""
@abc.abstractmethod
async def delete_by_thread(self, thread_id: str) -> int:
"""Delete all events for a thread. Return the number of deleted events."""
@abc.abstractmethod
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
"""Delete all events for a specific run. Return the number of deleted events."""
@@ -0,0 +1,286 @@
"""SQLAlchemy-backed RunEventStore implementation.
Persists events to the ``run_events`` table. Trace content is truncated
at ``max_trace_content`` bytes to avoid bloating the database.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
logger = logging.getLogger(__name__)
class DbRunEventStore(RunEventStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
self._sf = session_factory
self._max_trace_content = max_trace_content
@staticmethod
def _row_to_dict(row: RunEventRow) -> dict:
d = row.to_dict()
d["metadata"] = d.pop("event_metadata", {})
val = d.get("created_at")
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
d.pop("id", None)
# Restore dict content that was JSON-serialized on write
raw = d.get("content", "")
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
try:
d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError):
# Content looked like JSON (content_is_dict flag) but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
if category == "trace":
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
encoded = text.encode("utf-8")
if len(encoded) > self._max_trace_content:
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _user_id_from_context() -> str | None:
"""Soft read of user_id from contextvar for write paths.
Returns ``None`` (no filter / no stamp) if contextvar is unset,
which is the expected case for background worker writes. HTTP
request writes will have the contextvar set by auth middleware
and get their user_id stamped automatically.
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
to a VARCHAR column ("type 'UUID' is not supported") — the
INSERT would silently roll back and the worker would hang.
"""
user = get_current_user()
return str(user.id) if user is not None else None
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
This opens a dedicated transaction with a FOR UPDATE lock to
assign a monotonic *seq*. For high-throughput writes use
:meth:`put_batch`, which acquires the lock once for the whole
batch. Currently the only caller is ``worker.run_agent`` for
the initial ``human_message`` event (once per run).
"""
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
user_id=user_id,
event_type=event_type,
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
)
session.add(row)
return self._row_to_dict(row)
async def put_batch(self, events):
if not events:
return []
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
seq = max_seq or 0
rows = []
for e in events:
seq += 1
content = e.get("content", "")
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
if isinstance(content, dict):
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
user_id=e.get("user_id", user_id),
event_type=e["event_type"],
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
)
session.add(row)
rows.append(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(
self,
thread_id,
*,
limit=50,
before_seq=None,
after_seq=None,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
stmt = stmt.where(RunEventRow.seq > after_seq)
if after_seq is not None:
# Forward pagination: first `limit` records after cursor
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
else:
# before_seq or default (latest): take last `limit` records, return ascending
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def list_events(
self,
thread_id,
run_id,
*,
event_types=None,
limit=500,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_messages_by_run(
self,
thread_id,
run_id,
*,
limit=50,
before_seq=None,
after_seq=None,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(
RunEventRow.thread_id == thread_id,
RunEventRow.run_id == run_id,
RunEventRow.category == "message",
)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
stmt = stmt.where(RunEventRow.seq > after_seq)
if after_seq is not None:
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
else:
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def count_messages(
self,
thread_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def delete_by_thread(
self,
thread_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id]
if resolved_user_id is not None:
count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
async def delete_by_run(
self,
thread_id,
run_id,
*,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
if resolved_user_id is not None:
count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
@@ -0,0 +1,187 @@
"""JSONL file-backed RunEventStore implementation.
Each run's events are stored in a single file:
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
All categories (message, trace, lifecycle) are in the same file.
This backend is suitable for lightweight single-node deployments.
Known trade-off: ``list_messages()`` must scan all run files for a
thread since messages from multiple runs need unified seq ordering.
``list_events()`` reads only one file -- the fast path.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import UTC, datetime
from pathlib import Path
from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$")
class JsonlRunEventStore(RunEventStore):
def __init__(self, base_dir: str | Path | None = None):
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
@staticmethod
def _validate_id(value: str, label: str) -> str:
"""Validate that an ID is safe for use in filesystem paths."""
if not value or not _SAFE_ID_PATTERN.match(value):
raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
return value
def _thread_dir(self, thread_id: str) -> Path:
self._validate_id(thread_id, "thread_id")
return self._base_dir / "threads" / thread_id / "runs"
def _run_file(self, thread_id: str, run_id: str) -> Path:
self._validate_id(run_id, "run_id")
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
def _next_seq(self, thread_id: str) -> int:
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
return self._seq_counters[thread_id]
def _ensure_seq_loaded(self, thread_id: str) -> None:
"""Load max seq from existing files if not yet cached."""
if thread_id in self._seq_counters:
return
max_seq = 0
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
for line in f.read_text(encoding="utf-8").strip().splitlines():
try:
record = json.loads(line)
max_seq = max(max_seq, record.get("seq", 0))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue
self._seq_counters[thread_id] = max_seq
def _write_record(self, record: dict) -> None:
path = self._run_file(record["thread_id"], record["run_id"])
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "a", encoding="utf-8") as f:
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
def _read_thread_events(self, thread_id: str) -> list[dict]:
"""Read all events for a thread, sorted by seq."""
events = []
thread_dir = self._thread_dir(thread_id)
if not thread_dir.exists():
return events
for f in sorted(thread_dir.glob("*.jsonl")):
for line in f.read_text(encoding="utf-8").strip().splitlines():
if not line:
continue
try:
events.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
"""Read events for a specific run file."""
path = self._run_file(thread_id, run_id)
if not path.exists():
return []
events = []
for line in path.read_text(encoding="utf-8").strip().splitlines():
if not line:
continue
try:
events.append(json.loads(line))
except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", path)
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
self._ensure_seq_loaded(thread_id)
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._write_record(record)
return record
async def put_batch(self, events):
if not events:
return []
results = []
for ev in events:
record = await self.put(**ev)
results.append(record)
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._read_thread_events(thread_id)
messages = [e for e in all_events if e.get("category") == "message"]
if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq]
return messages[-limit:]
elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq]
return messages[:limit]
else:
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
events = self._read_run_events(thread_id, run_id)
if event_types is not None:
events = [e for e in events if e.get("event_type") in event_types]
return events[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
events = self._read_run_events(thread_id, run_id)
filtered = [e for e in events if e.get("category") == "message"]
if before_seq is not None:
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
if after_seq is not None:
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
if after_seq is not None:
return filtered[:limit]
else:
return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id):
all_events = self._read_thread_events(thread_id)
return sum(1 for e in all_events if e.get("category") == "message")
async def delete_by_thread(self, thread_id):
all_events = self._read_thread_events(thread_id)
count = len(all_events)
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
f.unlink()
self._seq_counters.pop(thread_id, None)
return count
async def delete_by_run(self, thread_id, run_id):
events = self._read_run_events(thread_id, run_id)
count = len(events)
path = self._run_file(thread_id, run_id)
if path.exists():
path.unlink()
return count
@@ -0,0 +1,128 @@
"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests.
Thread-safe for single-process async usage (no threading locks needed
since all mutations happen within the same event loop).
"""
from __future__ import annotations
from datetime import UTC, datetime
from deerflow.runtime.events.store.base import RunEventStore
class MemoryRunEventStore(RunEventStore):
def __init__(self) -> None:
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
def _next_seq(self, thread_id: str) -> int:
current = self._seq_counters.get(thread_id, 0)
next_val = current + 1
self._seq_counters[thread_id] = next_val
return next_val
def _put_one(
self,
*,
thread_id: str,
run_id: str,
event_type: str,
category: str,
content: str | dict = "",
metadata: dict | None = None,
created_at: str | None = None,
) -> dict:
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._events.setdefault(thread_id, []).append(record)
return record
async def put(
self,
*,
thread_id,
run_id,
event_type,
category,
content="",
metadata=None,
created_at=None,
):
return self._put_one(
thread_id=thread_id,
run_id=run_id,
event_type=event_type,
category=category,
content=content,
metadata=metadata,
created_at=created_at,
)
async def put_batch(self, events):
results = []
for ev in events:
record = self._put_one(**ev)
results.append(record)
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._events.get(thread_id, [])
messages = [e for e in all_events if e["category"] == "message"]
if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq]
# Take the last `limit` records
return messages[-limit:]
elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq]
return messages[:limit]
else:
# Return the latest `limit` records, ascending
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
all_events = self._events.get(thread_id, [])
filtered = [e for e in all_events if e["run_id"] == run_id]
if event_types is not None:
filtered = [e for e in filtered if e["event_type"] in event_types]
return filtered[:limit]
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._events.get(thread_id, [])
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
if before_seq is not None:
filtered = [e for e in filtered if e["seq"] < before_seq]
if after_seq is not None:
filtered = [e for e in filtered if e["seq"] > after_seq]
if after_seq is not None:
return filtered[:limit]
else:
return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id):
all_events = self._events.get(thread_id, [])
return sum(1 for e in all_events if e["category"] == "message")
async def delete_by_thread(self, thread_id):
events = self._events.pop(thread_id, [])
self._seq_counters.pop(thread_id, None)
return len(events)
async def delete_by_run(self, thread_id, run_id):
all_events = self._events.get(thread_id, [])
if not all_events:
return 0
remaining = [e for e in all_events if e["run_id"] != run_id]
removed = len(all_events) - len(remaining)
self._events[thread_id] = remaining
return removed
@@ -0,0 +1,497 @@
"""Run event capture via LangChain callbacks.
RunJournal sits between LangChain's callback mechanism and the pluggable
RunEventStore. It standardizes callback data into RunEvent records and
handles token usage accumulation.
Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- on_chat_model_start captures structured prompts as llm_request (OpenAI format)
- on_llm_end emits llm_response in OpenAI Chat Completions format
- Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
"""
from __future__ import annotations
import asyncio
import logging
import time
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
if TYPE_CHECKING:
from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
class RunJournal(BaseCallbackHandler):
"""LangChain callback handler that captures events to RunEventStore."""
def __init__(
self,
run_id: str,
thread_id: str,
event_store: RunEventStore,
*,
track_token_usage: bool = True,
flush_threshold: int = 20,
):
super().__init__()
self.run_id = run_id
self.thread_id = thread_id
self._store = event_store
self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold
# Write buffer
self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
# Token accumulators
self._total_input_tokens = 0
self._total_output_tokens = 0
self._total_tokens = 0
self._llm_call_count = 0
self._lead_agent_tokens = 0
self._subagent_tokens = 0
self._middleware_tokens = 0
# Convenience fields
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
self._msg_count = 0
# Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
# LLM request/response tracking
self._llm_call_index = 0
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
self._cached_models: dict[str, str] = {} # langchain run_id -> model name
# Tool call ID cache
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
# -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_start",
category="lifecycle",
metadata={"input_preview": str(inputs)[:500]},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync()
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_error",
category="lifecycle",
content=str(error),
metadata={"error_type": type(error).__name__},
)
self._flush_sync()
# -- LLM callbacks --
def on_chat_model_start(self, serialized: dict, messages: list[list], *, run_id: UUID, **kwargs: Any) -> None:
"""Capture structured prompt messages for llm_request event."""
from deerflow.runtime.converters import langchain_messages_to_openai
rid = str(run_id)
self._llm_start_times[rid] = time.monotonic()
self._llm_call_index += 1
model_name = serialized.get("name", "")
self._cached_models[rid] = model_name
# Convert the first message list (LangChain passes list-of-lists)
prompt_msgs = messages[0] if messages else []
openai_msgs = langchain_messages_to_openai(prompt_msgs)
self._cached_prompts[rid] = openai_msgs
caller = self._identify_caller(kwargs)
self._put(
event_type="llm_request",
category="trace",
content={"model": model_name, "messages": openai_msgs},
metadata={"caller": caller, "llm_call_index": self._llm_call_index},
)
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
# Fallback: on_chat_model_start is preferred. This just tracks latency.
self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.converters import langchain_to_openai_completion
try:
message = response.generations[0][0].message
except (IndexError, AttributeError):
logger.debug("on_llm_end: could not extract message from response")
return
caller = self._identify_caller(kwargs)
# Latency
rid = str(run_id)
start = self._llm_start_times.pop(rid, None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# Resolve call index
call_index = self._llm_call_index
if rid not in self._cached_prompts:
# Fallback: on_chat_model_start was not called
self._llm_call_index += 1
call_index = self._llm_call_index
# Clean up caches
self._cached_prompts.pop(rid, None)
self._cached_models.pop(rid, None)
# Trace event: llm_response (OpenAI completion format)
content = getattr(message, "content", "")
self._put(
event_type="llm_response",
category="trace",
content=langchain_to_openai_completion(message),
metadata={
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
"llm_call_index": call_index,
},
)
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
if tool_calls:
# ai_tool_call: agent decided to use tools
self._put(
event_type="ai_tool_call",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
# ai_message: final text reply
self._put(
event_type="ai_message",
category="message",
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm_error", category="trace", content=str(error))
# -- Tool callbacks --
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
tool_call_id = kwargs.get("tool_call_id")
if tool_call_id:
self._tool_call_ids[str(run_id)] = tool_call_id
self._put(
event_type="tool_start",
category="trace",
metadata={
"tool_name": serialized.get("name", ""),
"tool_call_id": tool_call_id,
"args": str(input_str)[:2000],
},
)
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
from langgraph.types import Command
# Tools that update graph state return a ``Command`` (e.g.
# ``present_files``). LangGraph later unwraps the inner ToolMessage
# into checkpoint state, so to stay checkpoint-aligned we must
# extract it here rather than storing ``str(Command(...))``.
if isinstance(output, Command):
update = getattr(output, "update", None) or {}
inner_msgs = update.get("messages") if isinstance(update, dict) else None
if isinstance(inner_msgs, list):
inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None)
if inner_tool_msg is not None:
output = inner_tool_msg
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": status,
},
)
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
if name == "summarization":
data_dict = data if isinstance(data, dict) else {}
self._put(
event_type="summarization",
category="trace",
content=data_dict.get("summary", ""),
metadata={
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
"replaced_count": data_dict.get("replaced_count", 0),
},
)
self._put(
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
else:
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
self._put(
event_type=name,
category="trace",
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
)
# -- Internal methods --
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
self._buffer.append(
{
"thread_id": self.thread_id,
"run_id": self.run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"created_at": datetime.now(UTC).isoformat(),
}
)
if len(self._buffer) >= self._flush_threshold:
self._flush_sync()
def _flush_sync(self) -> None:
"""Best-effort flush of buffer to RunEventStore.
BaseCallbackHandler methods are synchronous. If an event loop is
running we schedule an async ``put_batch``; otherwise the events
stay in the buffer and are flushed later by the async ``flush()``
call in the worker's ``finally`` block.
"""
if not self._buffer:
return
# Skip if a flush is already in flight — avoids concurrent writes
# to the same SQLite file from multiple fire-and-forget tasks.
if self._pending_flush_tasks:
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# No event loop — keep events in buffer for later async flush.
return
batch = self._buffer.copy()
self._buffer.clear()
task = loop.create_task(self._flush_async(batch))
self._pending_flush_tasks.add(task)
task.add_done_callback(self._on_flush_done)
async def _flush_async(self, batch: list[dict]) -> None:
try:
await self._store.put_batch(batch)
except Exception:
logger.warning(
"Failed to flush %d events for run %s — returning to buffer",
len(batch),
self.run_id,
exc_info=True,
)
# Return failed events to buffer for retry on next flush
self._buffer = batch + self._buffer
def _on_flush_done(self, task: asyncio.Task) -> None:
self._pending_flush_tasks.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc:
logger.warning("Journal flush task failed: %s", exc)
def _identify_caller(self, kwargs: dict) -> str:
for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
# Default to lead_agent: the main agent graph does not inject
# callback tags, while subagents and middleware explicitly tag
# themselves.
return "lead_agent"
# -- Public methods (called by worker) --
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware state-change event.
Called by middleware implementations when they perform a meaningful
state change (e.g., title generation, summarization, HITL approval).
Pure-observation middleware should not call this.
Args:
tag: Short identifier for the middleware (e.g., "title", "summarize",
"guardrail"). Used to form event_type="middleware:{tag}".
name: Full middleware class name.
hook: Lifecycle hook that triggered the action (e.g., "after_model").
action: Specific action performed (e.g., "generate_title").
changes: Dict describing the state changes made.
"""
self._put(
event_type=f"middleware:{tag}",
category="middleware",
content={"name": name, "hook": hook, "action": action, "changes": changes},
)
async def flush(self) -> None:
"""Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._buffer:
batch = self._buffer[: self._flush_threshold]
del self._buffer[: self._flush_threshold]
try:
await self._store.put_batch(batch)
except Exception:
self._buffer = batch + self._buffer
raise
def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion."""
return {
"total_input_tokens": self._total_input_tokens,
"total_output_tokens": self._total_output_tokens,
"total_tokens": self._total_tokens,
"llm_call_count": self._llm_call_count,
"lead_agent_tokens": self._lead_agent_tokens,
"subagent_tokens": self._subagent_tokens,
"middleware_tokens": self._middleware_tokens,
"message_count": self._msg_count,
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
}
@@ -2,11 +2,12 @@
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import run_agent
from .worker import RunContext, run_agent
__all__ = [
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",
@@ -1,4 +1,4 @@
"""In-memory run registry."""
"""In-memory run registry with optional persistent RunStore backing."""
from __future__ import annotations
@@ -7,9 +7,13 @@ import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
@@ -38,11 +42,44 @@ class RunRecord:
class RunManager:
"""In-memory run registry. All mutations are protected by an asyncio lock."""
"""In-memory run registry with optional persistent RunStore backing.
def __init__(self) -> None:
All mutations are protected by an asyncio lock. When a ``store`` is
provided, serializable metadata is also persisted to the store so
that run history survives process restarts.
"""
def __init__(self, store: RunStore | None = None) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
self._store = store
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
"""Best-effort persist run record to backing store."""
if self._store is None:
return
try:
await self._store.put(
record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
multitask_strategy=record.multitask_strategy,
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def create(
self,
@@ -53,6 +90,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
@@ -71,6 +109,7 @@ class RunManager:
)
async with self._lock:
self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -96,6 +135,11 @@ class RunManager:
record.updated_at = _now_iso()
if error is not None:
record.error = error
if self._store is not None:
try:
await self._store.update_status(run_id, status.value, error=error)
except Exception:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
@@ -132,6 +176,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -185,6 +230,7 @@ class RunManager:
)
self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@@ -0,0 +1,4 @@
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.runs.store.memory import MemoryRunStore
__all__ = ["MemoryRunStore", "RunStore"]
@@ -0,0 +1,96 @@
"""Abstract interface for run metadata storage.
RunManager depends on this interface. Implementations:
- MemoryRunStore: in-memory dict (development, tests)
- Future: RunRepository backed by SQLAlchemy ORM
All methods accept an optional user_id for user isolation.
When user_id is None, no user filtering is applied (single-user mode).
"""
from __future__ import annotations
import abc
from typing import Any
class RunStore(abc.ABC):
@abc.abstractmethod
async def put(
self,
run_id: str,
*,
thread_id: str,
assistant_id: str | None = None,
user_id: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
error: str | None = None,
created_at: str | None = None,
follow_up_to_run_id: str | None = None,
) -> None:
pass
@abc.abstractmethod
async def get(self, run_id: str) -> dict[str, Any] | None:
pass
@abc.abstractmethod
async def list_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 100,
) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def update_status(
self,
run_id: str,
status: str,
*,
error: str | None = None,
) -> None:
pass
@abc.abstractmethod
async def delete(self, run_id: str) -> None:
pass
@abc.abstractmethod
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
pass
@abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens,
total_output_tokens, total_runs, by_model (model_name → {tokens, runs}),
by_caller ({lead_agent, subagent, middleware}).
"""
pass
@@ -0,0 +1,100 @@
"""In-memory RunStore. Used when database.backend=memory (default) and in tests.
Equivalent to the original RunManager._runs dict behavior.
"""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from deerflow.runtime.runs.store.base import RunStore
class MemoryRunStore(RunStore):
def __init__(self) -> None:
self._runs: dict[str, dict[str, Any]] = {}
async def put(
self,
run_id,
*,
thread_id,
assistant_id=None,
user_id=None,
status="pending",
multitask_strategy="reject",
metadata=None,
kwargs=None,
error=None,
created_at=None,
follow_up_to_run_id=None,
):
now = datetime.now(UTC).isoformat()
self._runs[run_id] = {
"run_id": run_id,
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": user_id,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
"kwargs": kwargs or {},
"error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"created_at": created_at or now,
"updated_at": now,
}
async def get(self, run_id):
return self._runs.get(run_id)
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
results.sort(key=lambda r: r["created_at"], reverse=True)
return results[:limit]
async def update_status(self, run_id, status, *, error=None):
if run_id in self._runs:
self._runs[run_id]["status"] = status
if error is not None:
self._runs[run_id]["error"] = error
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def delete(self, run_id):
self._runs.pop(run_id, None)
async def update_run_completion(self, run_id, *, status, **kwargs):
if run_id in self._runs:
self._runs[run_id]["status"] = status
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
by_model: dict[str, dict] = {}
for r in completed:
model = r.get("model_name") or "unknown"
entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
entry["tokens"] += r.get("total_tokens", 0)
entry["runs"] += 1
return {
"total_tokens": sum(r.get("total_tokens", 0) for r in completed),
"total_input_tokens": sum(r.get("total_input_tokens", 0) for r in completed),
"total_output_tokens": sum(r.get("total_output_tokens", 0) for r in completed),
"total_runs": len(completed),
"by_model": by_model,
"by_caller": {
"lead_agent": sum(r.get("lead_agent_tokens", 0) for r in completed),
"subagent": sum(r.get("subagent_tokens", 0) for r in completed),
"middleware": sum(r.get("middleware_tokens", 0) for r in completed),
},
}
@@ -19,8 +19,14 @@ import asyncio
import copy
import inspect
import logging
from typing import Any, Literal
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.config.app_config import AppConfig
from deerflow.config.deer_flow_context import DeerFlowContext
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge
@@ -33,13 +39,30 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
@dataclass(frozen=True)
class RunContext:
"""Infrastructure dependencies for a single agent run.
Groups checkpointer, store, and persistence-related singletons so that
``run_agent`` (and any future callers) receive one object instead of a
growing list of keyword arguments.
"""
checkpointer: Any
store: Any | None = field(default=None)
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
app_config: AppConfig | None = field(default=None)
async def run_agent(
bridge: StreamBridge,
run_manager: RunManager,
record: RunRecord,
*,
checkpointer: Any,
store: Any | None = None,
ctx: RunContext,
agent_factory: Any,
graph_input: dict,
config: dict,
@@ -50,6 +73,14 @@ async def run_agent(
) -> None:
"""Execute an agent in the background, publishing events to *bridge*."""
# Unpack infrastructure dependencies from RunContext.
checkpointer = ctx.checkpointer
store = ctx.store
event_store = ctx.event_store
run_events_config = ctx.run_events_config
thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id
thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"])
@@ -57,6 +88,10 @@ async def run_agent(
pre_run_snapshot: dict[str, Any] | None = None
snapshot_capture_failed = False
journal = None
journal = None
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
@@ -65,6 +100,38 @@ async def run_agent(
)
try:
# Initialize RunJournal + write human_message event.
# These are inside the try block so any exception (e.g. a DB
# error writing the event) flows through the except/finally
# path that publishes an "end" event to the SSE bridge —
# otherwise a failure here would leave the stream hanging
# with no terminator.
if event_store is not None:
from deerflow.runtime.journal import RunJournal
journal = RunJournal(
run_id=run_id,
thread_id=thread_id,
event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# 1. Mark running
await run_manager.set_status(run_id, RunStatus.running)
@@ -98,17 +165,21 @@ async def run_agent(
# 3. Build the agent
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
# Inject runtime context so middlewares can access thread_id
# (langgraph-cli does this automatically; we must do it manually)
runtime = Runtime(context={"thread_id": thread_id}, store=store)
# If the caller already set a ``context`` key (LangGraph >= 0.6.0
# prefers it over ``configurable`` for thread-level data), make
# sure ``thread_id`` is available there too.
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Construct typed context for the agent run.
# LangGraph's astream(context=...) injects this into Runtime.context
# so middleware/tools can access it via resolve_context().
if ctx.app_config is None:
raise RuntimeError("RunContext.app_config is required — Gateway must populate it via get_run_context")
deer_flow_context = DeerFlowContext(
app_config=ctx.app_config,
thread_id=thread_id,
)
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None:
config.setdefault("callbacks", []).append(journal)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
@@ -155,7 +226,7 @@ async def run_agent(
if len(lg_modes) == 1 and not stream_subgraphs:
# Single mode, no subgraphs: astream yields raw chunks
single_mode = lg_modes[0]
async for chunk in agent.astream(graph_input, config=runnable_config, stream_mode=single_mode):
async for chunk in agent.astream(graph_input, config=runnable_config, context=deer_flow_context, stream_mode=single_mode):
if record.abort_event.is_set():
logger.info("Run %s abort requested — stopping", run_id)
break
@@ -166,6 +237,7 @@ async def run_agent(
async for item in agent.astream(
graph_input,
config=runnable_config,
context=deer_flow_context,
stream_mode=lg_modes,
subgraphs=stream_subgraphs,
):
@@ -236,6 +308,41 @@ async def run_agent(
)
finally:
# Flush any buffered journal events and persist completion data
if journal is not None:
try:
await journal.flush()
except Exception:
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
try:
# Persist token usage + convenience fields to RunStore
completion = journal.get_completion_data()
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
except Exception:
logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True)
# Sync title from checkpoint to threads_meta.display_name
if checkpointer is not None and thread_store is not None:
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is not None:
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
title = ckpt.get("channel_values", {}).get("title")
if title:
await thread_store.update_display_name(thread_id, title)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
# Update threads_meta status based on run outcome
if thread_store is not None:
try:
final_status = "idle" if record.status == RunStatus.success else record.status.value
await thread_store.update_status(thread_id, final_status)
except Exception:
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60))
@@ -355,6 +462,31 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _extract_human_message(graph_input: dict) -> HumanMessage | None:
"""Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get
the checkpoint-aligned serialization format.
"""
from langchain_core.messages import HumanMessage
messages = graph_input.get("messages")
if not messages:
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
content = last.content
return HumanMessage(content=content)
if isinstance(last, dict):
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
def _unpack_stream_item(
item: Any,
lg_modes: list[str],
@@ -23,7 +23,7 @@ from collections.abc import AsyncIterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -86,28 +86,26 @@ async def _async_store(config) -> AsyncIterator[BaseStore]:
@contextlib.asynccontextmanager
async def make_store() -> AsyncIterator[BaseStore]:
async def make_store(app_config: AppConfig) -> AsyncIterator[BaseStore]:
"""Async context manager that yields a Store whose backend matches the
configured checkpointer.
Reads from the same ``checkpointer`` section of *config.yaml* used by
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
that both singletons always use the same persistence technology::
async with make_store() as store:
async with make_store(app_config) as store:
app.state.store = store
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
``checkpointer`` section is configured (emits a WARNING in that case).
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
async with _async_store(config.checkpointer) as store:
async with _async_store(app_config.checkpointer) as store:
yield store
@@ -26,7 +26,7 @@ from collections.abc import Iterator
from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -100,7 +100,7 @@ _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive
def get_store() -> BaseStore:
def get_store(app_config: AppConfig) -> BaseStore:
"""Return the global sync Store singleton, creating it on first call.
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
@@ -115,19 +115,10 @@ def get_store() -> BaseStore:
if _store is not None:
return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so
# that tests that set the global checkpointer config explicitly remain isolated.
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config()
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config()
# See matching comment in checkpointer/provider.py: a missing config.yaml
# is a deployment error, not a cue to silently pick InMemoryStore. Only
# the explicit "no checkpointer section" path falls through to memory.
config = app_config.checkpointer
if config is None:
from langgraph.store.memory import InMemoryStore
@@ -163,26 +154,25 @@ def reset_store() -> None:
@contextlib.contextmanager
def store_context() -> Iterator[BaseStore]:
def store_context(app_config: AppConfig) -> Iterator[BaseStore]:
"""Sync context manager that yields a Store and cleans up on exit.
Unlike :func:`get_store`, this does **not** cache the instance — each
``with`` block creates and destroys its own connection. Use it in CLI
scripts or tests where you want deterministic cleanup::
with store_context() as store:
with store_context(app_config) as store:
store.put(("threads",), thread_id, {...})
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
checkpointer is configured in *config.yaml*.
"""
config = get_app_config()
if config.checkpointer is None:
if app_config.checkpointer is None:
from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
yield InMemoryStore()
return
with _sync_store_cm(config.checkpointer) as store:
with _sync_store_cm(app_config.checkpointer) as store:
yield store
@@ -1,7 +1,7 @@
"""Async stream bridge factory.
Provides an **async context manager** aligned with
:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`.
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`.
Usage (e.g. FastAPI lifespan)::
@@ -17,7 +17,7 @@ import contextlib
import logging
from collections.abc import AsyncIterator
from deerflow.config.stream_bridge_config import get_stream_bridge_config
from deerflow.config.app_config import AppConfig
from .base import StreamBridge
@@ -25,14 +25,13 @@ logger = logging.getLogger(__name__)
@contextlib.asynccontextmanager
async def make_stream_bridge(config=None) -> AsyncIterator[StreamBridge]:
async def make_stream_bridge(app_config: AppConfig) -> AsyncIterator[StreamBridge]:
"""Async context manager that yields a :class:`StreamBridge`.
Falls back to :class:`MemoryStreamBridge` when no configuration is
provided and nothing is set globally.
Falls back to :class:`MemoryStreamBridge` when no ``stream_bridge``
section is configured.
"""
if config is None:
config = get_stream_bridge_config()
config = app_config.stream_bridge
if config is None or config.type == "memory":
from deerflow.runtime.stream_bridge.memory import MemoryStreamBridge
@@ -0,0 +1,167 @@
"""Request-scoped user context for user-based authorization.
This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``user_id`` boilerplate.
Three-state semantics for the repository ``user_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``):
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
raise :class:`RuntimeError` if unset.
- Explicit ``str``: use the provided value, overriding contextvar.
- Explicit ``None``: no WHERE clause — used only by migration scripts
and admin CLIs that intentionally bypass isolation.
Dependency direction
--------------------
``persistence`` (lower layer) reads from this module; ``gateway.auth``
(higher layer) writes to it. ``CurrentUser`` is defined here as a
:class:`typing.Protocol` so that ``persistence`` never needs to import
the concrete ``User`` class from ``gateway.auth.models``. Any object
with an ``.id: str`` attribute structurally satisfies the protocol.
Asyncio semantics
-----------------
``ContextVar`` is task-local under asyncio, not thread-local. Each
FastAPI request runs in its own task, so the context is naturally
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
parent task's context, which is typically the intended behaviour; if
a background task must *not* see the foreground user, wrap it with
``contextvars.copy_context()`` to get a clean copy.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Final, Protocol, runtime_checkable
@runtime_checkable
class CurrentUser(Protocol):
"""Structural type for the current authenticated user.
Any object with an ``.id: str`` attribute satisfies this protocol.
Concrete implementations live in ``app.gateway.auth.models.User``.
"""
id: str
_current_user: Final[ContextVar[CurrentUser | None]] = ContextVar("deerflow_current_user", default=None)
def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
"""Set the current user for this async task.
Returns a reset token that should be passed to
:func:`reset_current_user` in a ``finally`` block to restore the
previous context.
"""
return _current_user.set(user)
def reset_current_user(token: Token[CurrentUser | None]) -> None:
"""Restore the context to the state captured by ``token``."""
_current_user.reset(token)
def get_current_user() -> CurrentUser | None:
"""Return the current user, or ``None`` if unset.
Safe to call in any context. Used by code paths that can proceed
without a user (e.g. migration scripts, public endpoints).
"""
return _current_user.get()
def require_current_user() -> CurrentUser:
"""Return the current user, or raise :class:`RuntimeError`.
Used by repository code that must not be called outside a
request-authenticated context. The error message is phrased so
that a caller debugging a stack trace can locate the offending
code path.
"""
user = _current_user.get()
if user is None:
raise RuntimeError("repository accessed without user context")
return user
# ---------------------------------------------------------------------------
# Effective user_id helpers (filesystem isolation)
# ---------------------------------------------------------------------------
DEFAULT_USER_ID: Final[str] = "default"
def get_effective_user_id() -> str:
"""Return the current user's id as a string, or DEFAULT_USER_ID if unset.
Unlike :func:`require_current_user` this never raises — it is designed
for filesystem-path resolution where a valid user bucket is always needed.
"""
user = _current_user.get()
if user is None:
return DEFAULT_USER_ID
return str(user.id)
# ---------------------------------------------------------------------------
# Sentinel-based user_id resolution
# ---------------------------------------------------------------------------
#
# Repository methods accept a ``user_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_user_id`.
class _AutoSentinel:
"""Singleton marker meaning 'resolve user_id from contextvar'."""
_instance: _AutoSentinel | None = None
def __new__(cls) -> _AutoSentinel:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "<AUTO>"
AUTO: Final[_AutoSentinel] = _AutoSentinel()
def resolve_user_id(
value: str | None | _AutoSentinel,
*,
method_name: str = "repository method",
) -> str | None:
"""Resolve the user_id parameter passed to a repository method.
Three-state semantics:
- :data:`AUTO` (default): read from contextvar; raise
:class:`RuntimeError` if no user is in context. This is the
common case for request-scoped calls.
- Explicit ``str``: use the provided id verbatim, overriding any
contextvar value. Useful for tests and admin-override flows.
- Explicit ``None``: no filter — the repository should skip the
user_id WHERE clause entirely. Reserved for migration scripts
and CLI tools that intentionally bypass isolation.
"""
if isinstance(value, _AutoSentinel):
user = _current_user.get()
if user is None:
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
# ``UUID`` for the API surface, but the persistence layer
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
# not supported"). Honour the documented return type here
# rather than ripple a type change through every caller.
return str(user.id)
return value