mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
refactor(harness): remove old persistence layer from deerflow package
Remove the following deprecated modules: - deerflow/persistence/ - old SQL persistence models and repositories - deerflow/config/checkpointer_config.py - checkpointer configuration - deerflow/config/database_config.py - database configuration - deerflow/runtime/checkpointer/ - checkpointer providers - deerflow/runtime/store/ - store providers - deerflow/runtime/events/ - event store implementations - deerflow/runtime/journal.py - run journal These components are replaced by the new storage layer in app/infra/. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,9 +0,0 @@
|
||||
from .async_provider import make_checkpointer
|
||||
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
|
||||
|
||||
__all__ = [
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"checkpointer_context",
|
||||
"make_checkpointer",
|
||||
]
|
||||
@@ -1,159 +0,0 @@
|
||||
"""Async checkpointer factory.
|
||||
|
||||
Provides an **async context manager** for long-running async servers that need
|
||||
proper resource cleanup.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||
|
||||
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that constructs and tears down a checkpointer."""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
|
||||
if db_config.backend == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if db_config.backend == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = db_config.checkpointer_sqlite_path
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
if db_config.backend == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not db_config.postgres_url:
|
||||
raise ValueError("database.postgres_url is required for the postgres backend")
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that yields a checkpointer for the caller's lifetime.
|
||||
Resources are opened on enter and closed on exit -- no global state::
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
|
||||
Priority:
|
||||
1. Legacy ``checkpointer:`` config section (backward compatible)
|
||||
2. Unified ``database:`` config section
|
||||
3. Default InMemorySaver
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
|
||||
# Legacy: standalone checkpointer config takes precedence
|
||||
if config.checkpointer is not None:
|
||||
async with _async_checkpointer(config.checkpointer) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
# Unified database config
|
||||
db_config = getattr(config, "database", None)
|
||||
if db_config is not None and db_config.backend != "memory":
|
||||
async with _async_checkpointer_from_database(db_config) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
# Default: in-memory
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Sync checkpointer factory.
|
||||
|
||||
Provides a **sync singleton** and a **sync context manager** for LangGraph
|
||||
graph compilation and CLI tools.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
cp = get_checkpointer()
|
||||
|
||||
# One-shot — fresh connection, closed on block exit
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message constants — imported by aio.provider too
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
"""Context manager that creates and tears down a sync checkpointer.
|
||||
|
||||
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
|
||||
underlying connections or pools is handled by higher-level helpers in
|
||||
this module (such as the singleton factory or context manager); this
|
||||
function does not return a separate cleanup callback.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
with PostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using PostgresSaver")
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
"""Return the global sync checkpointer singleton, creating it on first call.
|
||||
|
||||
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
ValueError: If ``connection_string`` is missing for a backend that requires it.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
# Ensure app config is loaded before checking checkpointer config
|
||||
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
|
||||
# but hasn't been loaded yet
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
return _checkpointer
|
||||
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
|
||||
return _checkpointer
|
||||
|
||||
|
||||
def reset_checkpointer() -> None:
|
||||
"""Reset the sync singleton, forcing recreation on the next call.
|
||||
|
||||
Closes any open backend connections and clears the cached instance.
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
if _checkpointer_ctx is not None:
|
||||
try:
|
||||
_checkpointer_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during checkpointer cleanup", exc_info=True)
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
"""Sync context manager that yields a checkpointer and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
|
||||
each ``with`` block creates and destroys its own connection. Use it in
|
||||
CLI scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
with _sync_checkpointer_cm(config.checkpointer) as saver:
|
||||
yield saver
|
||||
@@ -1,4 +0,0 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore"]
|
||||
@@ -1,26 +0,0 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
def make_run_event_store(config=None) -> RunEventStore:
|
||||
"""Create a RunEventStore based on run_events.backend configuration."""
|
||||
if config is None or config.backend == "memory":
|
||||
return MemoryRunEventStore()
|
||||
if config.backend == "db":
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
# database.backend=memory but run_events.backend=db -> fallback
|
||||
return MemoryRunEventStore()
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
|
||||
if config.backend == "jsonl":
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
return JsonlRunEventStore()
|
||||
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
|
||||
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
|
||||
@@ -1,109 +0,0 @@
|
||||
"""Abstract interface for run event storage.
|
||||
|
||||
RunEventStore is the unified storage interface for run event streams.
|
||||
Messages (frontend display) and execution traces (debugging/audit) go
|
||||
through the same interface, distinguished by the ``category`` field.
|
||||
|
||||
Implementations:
|
||||
- MemoryRunEventStore: in-memory dict (development, tests)
|
||||
- Future: DB-backed store (SQLAlchemy ORM), JSONL file store
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class RunEventStore(abc.ABC):
|
||||
"""Run event stream storage interface.
|
||||
|
||||
All implementations must guarantee:
|
||||
1. put() events are retrievable in subsequent queries
|
||||
2. seq is strictly increasing within the same thread
|
||||
3. list_messages() only returns category="message" events
|
||||
4. list_events() returns all events for the specified run
|
||||
5. Returned dicts match the RunEvent field structure
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def put(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
event_type: str,
|
||||
category: str,
|
||||
content: str | dict = "",
|
||||
metadata: dict | None = None,
|
||||
created_at: str | None = None,
|
||||
) -> dict:
|
||||
"""Write an event, auto-assign seq, return the complete record."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def put_batch(self, events: list[dict]) -> list[dict]:
|
||||
"""Batch-write events. Used by RunJournal flush buffer.
|
||||
|
||||
Each dict's keys match put()'s keyword arguments.
|
||||
Returns complete records with seq assigned.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages (category=message) for a thread, ordered by seq ascending.
|
||||
|
||||
Supports bidirectional cursor pagination:
|
||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||
- neither: return the latest ``limit`` records (ascending)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[dict]:
|
||||
"""Return the full event stream for a run, ordered by seq ascending.
|
||||
|
||||
Optionally filter by event_types.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
|
||||
|
||||
Supports bidirectional cursor pagination:
|
||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||
- neither: return the latest ``limit`` records (ascending)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
"""Count displayable messages (category=message) in a thread."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
"""Delete all events for a thread. Return the number of deleted events."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
"""Delete all events for a specific run. Return the number of deleted events."""
|
||||
@@ -1,286 +0,0 @@
|
||||
"""SQLAlchemy-backed RunEventStore implementation.
|
||||
|
||||
Persists events to the ``run_events`` table. Trace content is truncated
|
||||
at ``max_trace_content`` bytes to avoid bloating the database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DbRunEventStore(RunEventStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
|
||||
self._sf = session_factory
|
||||
self._max_trace_content = max_trace_content
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: RunEventRow) -> dict:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("event_metadata", {})
|
||||
val = d.get("created_at")
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
d.pop("id", None)
|
||||
# Restore dict content that was JSON-serialized on write
|
||||
raw = d.get("content", "")
|
||||
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
|
||||
try:
|
||||
d["content"] = json.loads(raw)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Content looked like JSON (content_is_dict flag) but failed to parse;
|
||||
# keep the raw string as-is.
|
||||
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
|
||||
return d
|
||||
|
||||
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
|
||||
if category == "trace":
|
||||
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
|
||||
encoded = text.encode("utf-8")
|
||||
if len(encoded) > self._max_trace_content:
|
||||
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
|
||||
content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
|
||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
request writes will have the contextvar set by auth middleware
|
||||
and get their user_id stamped automatically.
|
||||
|
||||
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
|
||||
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
|
||||
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
|
||||
to a VARCHAR column ("type 'UUID' is not supported") — the
|
||||
INSERT would silently roll back and the worker would hang.
|
||||
"""
|
||||
user = get_current_user()
|
||||
return str(user.id) if user is not None else None
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
This opens a dedicated transaction with a FOR UPDATE lock to
|
||||
assign a monotonic *seq*. For high-throughput writes use
|
||||
:meth:`put_batch`, which acquires the lock once for the whole
|
||||
batch. Currently the only caller is ``worker.run_agent`` for
|
||||
the initial ``human_message`` event (once per run).
|
||||
"""
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
user_id=user_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
seq += 1
|
||||
content = e.get("content", "")
|
||||
category = e.get("category", "trace")
|
||||
metadata = e.get("metadata")
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
user_id=e.get("user_id", user_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
rows.append(row)
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
# Forward pagination: first `limit` records after cursor
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
# before_seq or default (latest): take last `limit` records, return ascending
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(
|
||||
RunEventRow.thread_id == thread_id,
|
||||
RunEventRow.run_id == run_id,
|
||||
RunEventRow.category == "message",
|
||||
)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def count_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def delete_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
async def delete_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
@@ -1,187 +0,0 @@
|
||||
"""JSONL file-backed RunEventStore implementation.
|
||||
|
||||
Each run's events are stored in a single file:
|
||||
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
|
||||
|
||||
All categories (message, trace, lifecycle) are in the same file.
|
||||
This backend is suitable for lightweight single-node deployments.
|
||||
|
||||
Known trade-off: ``list_messages()`` must scan all run files for a
|
||||
thread since messages from multiple runs need unified seq ordering.
|
||||
``list_events()`` reads only one file -- the fast path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
class JsonlRunEventStore(RunEventStore):
|
||||
def __init__(self, base_dir: str | Path | None = None):
|
||||
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
|
||||
|
||||
@staticmethod
|
||||
def _validate_id(value: str, label: str) -> str:
|
||||
"""Validate that an ID is safe for use in filesystem paths."""
|
||||
if not value or not _SAFE_ID_PATTERN.match(value):
|
||||
raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
|
||||
return value
|
||||
|
||||
def _thread_dir(self, thread_id: str) -> Path:
|
||||
self._validate_id(thread_id, "thread_id")
|
||||
return self._base_dir / "threads" / thread_id / "runs"
|
||||
|
||||
def _run_file(self, thread_id: str, run_id: str) -> Path:
|
||||
self._validate_id(run_id, "run_id")
|
||||
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
|
||||
|
||||
def _next_seq(self, thread_id: str) -> int:
|
||||
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
|
||||
return self._seq_counters[thread_id]
|
||||
|
||||
def _ensure_seq_loaded(self, thread_id: str) -> None:
|
||||
"""Load max seq from existing files if not yet cached."""
|
||||
if thread_id in self._seq_counters:
|
||||
return
|
||||
max_seq = 0
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
for line in f.read_text(encoding="utf-8").strip().splitlines():
|
||||
try:
|
||||
record = json.loads(line)
|
||||
max_seq = max(max_seq, record.get("seq", 0))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
self._seq_counters[thread_id] = max_seq
|
||||
|
||||
def _write_record(self, record: dict) -> None:
|
||||
path = self._run_file(record["thread_id"], record["run_id"])
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
|
||||
|
||||
def _read_thread_events(self, thread_id: str) -> list[dict]:
|
||||
"""Read all events for a thread, sorted by seq."""
|
||||
events = []
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if not thread_dir.exists():
|
||||
return events
|
||||
for f in sorted(thread_dir.glob("*.jsonl")):
|
||||
for line in f.read_text(encoding="utf-8").strip().splitlines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
|
||||
"""Read events for a specific run file."""
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
events = []
|
||||
for line in path.read_text(encoding="utf-8").strip().splitlines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", path)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
|
||||
self._ensure_seq_loaded(thread_id)
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._write_record(record)
|
||||
return record
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
results = []
|
||||
for ev in events:
|
||||
record = await self.put(**ev)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
messages = [e for e in all_events if e.get("category") == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] < before_seq]
|
||||
return messages[-limit:]
|
||||
elif after_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] > after_seq]
|
||||
return messages[:limit]
|
||||
else:
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
if event_types is not None:
|
||||
events = [e for e in events if e.get("event_type") in event_types]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
filtered = [e for e in events if e.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
return sum(1 for e in all_events if e.get("category") == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
count = len(all_events)
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
f.unlink()
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
count = len(events)
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return count
|
||||
@@ -1,128 +0,0 @@
|
||||
"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests.
|
||||
|
||||
Thread-safe for single-process async usage (no threading locks needed
|
||||
since all mutations happen within the same event loop).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
|
||||
class MemoryRunEventStore(RunEventStore):
|
||||
def __init__(self) -> None:
|
||||
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
||||
|
||||
def _next_seq(self, thread_id: str) -> int:
|
||||
current = self._seq_counters.get(thread_id, 0)
|
||||
next_val = current + 1
|
||||
self._seq_counters[thread_id] = next_val
|
||||
return next_val
|
||||
|
||||
def _put_one(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
event_type: str,
|
||||
category: str,
|
||||
content: str | dict = "",
|
||||
metadata: dict | None = None,
|
||||
created_at: str | None = None,
|
||||
) -> dict:
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._events.setdefault(thread_id, []).append(record)
|
||||
return record
|
||||
|
||||
async def put(
|
||||
self,
|
||||
*,
|
||||
thread_id,
|
||||
run_id,
|
||||
event_type,
|
||||
category,
|
||||
content="",
|
||||
metadata=None,
|
||||
created_at=None,
|
||||
):
|
||||
return self._put_one(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=content,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
async def put_batch(self, events):
|
||||
results = []
|
||||
for ev in events:
|
||||
record = self._put_one(**ev)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
messages = [e for e in all_events if e["category"] == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] < before_seq]
|
||||
# Take the last `limit` records
|
||||
return messages[-limit:]
|
||||
elif after_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] > after_seq]
|
||||
return messages[:limit]
|
||||
else:
|
||||
# Return the latest `limit` records, ascending
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
filtered = [e for e in all_events if e["run_id"] == run_id]
|
||||
if event_types is not None:
|
||||
filtered = [e for e in filtered if e["event_type"] in event_types]
|
||||
return filtered[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
return sum(1 for e in all_events if e["category"] == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
events = self._events.pop(thread_id, [])
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return len(events)
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
if not all_events:
|
||||
return 0
|
||||
remaining = [e for e in all_events if e["run_id"] != run_id]
|
||||
removed = len(all_events) - len(remaining)
|
||||
self._events[thread_id] = remaining
|
||||
return removed
|
||||
@@ -1,374 +0,0 @@
|
||||
"""Run event capture via LangChain callbacks.
|
||||
|
||||
RunJournal sits between LangChain's callback mechanism and the pluggable
|
||||
RunEventStore. It standardizes callback data into RunEvent records and
|
||||
handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
||||
extracts the first human message for run.input, because it is more reliable than
|
||||
on_chain_start (fires on every node) — messages here are fully structured.
|
||||
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunJournal(BaseCallbackHandler):
|
||||
"""LangChain callback handler that captures events to RunEventStore."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
event_store: RunEventStore,
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
flush_threshold: int = 20,
|
||||
):
|
||||
super().__init__()
|
||||
self.run_id = run_id
|
||||
self.thread_id = thread_id
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._flush_threshold = flush_threshold
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._llm_call_count = 0
|
||||
self._lead_agent_tokens = 0
|
||||
self._subagent_tokens = 0
|
||||
self._middleware_tokens = 0
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
self._first_human_msg: str | None = None
|
||||
self._msg_count = 0
|
||||
|
||||
# Latency tracking
|
||||
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
||||
|
||||
# LLM request/response tracking
|
||||
self._llm_call_index = 0
|
||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
caller = self._identify_caller(tags)
|
||||
if parent_run_id is None:
|
||||
# Root graph invocation — emit a single trace event for the run start.
|
||||
chain_name = (serialized or {}).get("name", "unknown")
|
||||
self._put(
|
||||
event_type="run.start",
|
||||
category="trace",
|
||||
content={"chain": chain_name},
|
||||
metadata={"caller": caller, **(metadata or {})},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(
|
||||
event_type="run.error",
|
||||
category="error",
|
||||
content=str(error),
|
||||
metadata={"error_type": type(error).__name__},
|
||||
)
|
||||
self._flush_sync()
|
||||
|
||||
# -- LLM callbacks --
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict,
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Capture structured prompt messages for llm_request event.
|
||||
|
||||
This is also the canonical place to extract the first human message:
|
||||
messages are fully structured here, it fires only on real LLM calls,
|
||||
and the content is never compressed by checkpoint trimming.
|
||||
"""
|
||||
rid = str(run_id)
|
||||
self._llm_start_times[rid] = time.monotonic()
|
||||
self._llm_call_index += 1
|
||||
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
||||
self._cached_prompts[rid] = []
|
||||
|
||||
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
||||
|
||||
# Capture the first human message sent to any LLM in this run.
|
||||
if not self._first_human_msg:
|
||||
for batch in messages.reversed():
|
||||
for m in batch.reversed():
|
||||
if isinstance(m, HumanMessage) and m.name != "summary":
|
||||
caller = self._identify_caller(tags)
|
||||
self.set_first_human_message(m.text)
|
||||
self._put(
|
||||
event_type="llm.human.input",
|
||||
category="message",
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||
|
||||
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
||||
messages: list[AnyMessage] = []
|
||||
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
||||
for generation in response.generations:
|
||||
for gen in generation:
|
||||
if hasattr(gen, "message"):
|
||||
messages.append(gen.message)
|
||||
else:
|
||||
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
||||
|
||||
for message in messages:
|
||||
caller = self._identify_caller(tags)
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
self._put(
|
||||
event_type="llm.ai.response",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
|
||||
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
||||
"""Handle tool start event, cache tool call ID for later correlation"""
|
||||
tool_call_id = str(run_id)
|
||||
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
||||
|
||||
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
||||
"""Handle tool end event, append message and clear node data"""
|
||||
try:
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
||||
finally:
|
||||
logger.info(f"Tool end for node {run_id}")
|
||||
|
||||
# -- Internal methods --
|
||||
|
||||
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
|
||||
self._buffer.append(
|
||||
{
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
if len(self._buffer) >= self._flush_threshold:
|
||||
self._flush_sync()
|
||||
|
||||
def _flush_sync(self) -> None:
|
||||
"""Best-effort flush of buffer to RunEventStore.
|
||||
|
||||
BaseCallbackHandler methods are synchronous. If an event loop is
|
||||
running we schedule an async ``put_batch``; otherwise the events
|
||||
stay in the buffer and are flushed later by the async ``flush()``
|
||||
call in the worker's ``finally`` block.
|
||||
"""
|
||||
if not self._buffer:
|
||||
return
|
||||
# Skip if a flush is already in flight — avoids concurrent writes
|
||||
# to the same SQLite file from multiple fire-and-forget tasks.
|
||||
if self._pending_flush_tasks:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop — keep events in buffer for later async flush.
|
||||
return
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
task = loop.create_task(self._flush_async(batch))
|
||||
self._pending_flush_tasks.add(task)
|
||||
task.add_done_callback(self._on_flush_done)
|
||||
|
||||
async def _flush_async(self, batch: list[dict]) -> None:
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to flush %d events for run %s — returning to buffer",
|
||||
len(batch),
|
||||
self.run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Return failed events to buffer for retry on next flush
|
||||
self._buffer = batch + self._buffer
|
||||
|
||||
def _on_flush_done(self, task: asyncio.Task) -> None:
|
||||
self._pending_flush_tasks.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
logger.warning("Journal flush task failed: %s", exc)
|
||||
|
||||
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
||||
_tags = tags or kwargs.get("tags", [])
|
||||
for tag in _tags:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
# Default to lead_agent: the main agent graph does not inject
|
||||
# callback tags, while subagents and middleware explicitly tag
|
||||
# themselves.
|
||||
return "lead_agent"
|
||||
|
||||
# -- Public methods (called by worker) --
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
|
||||
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
|
||||
"""Record a middleware state-change event.
|
||||
|
||||
Called by middleware implementations when they perform a meaningful
|
||||
state change (e.g., title generation, summarization, HITL approval).
|
||||
Pure-observation middleware should not call this.
|
||||
|
||||
Args:
|
||||
tag: Short identifier for the middleware (e.g., "title", "summarize",
|
||||
"guardrail"). Used to form event_type="middleware:{tag}".
|
||||
name: Full middleware class name.
|
||||
hook: Lifecycle hook that triggered the action (e.g., "after_model").
|
||||
action: Specific action performed (e.g., "generate_title").
|
||||
changes: Dict describing the state changes made.
|
||||
"""
|
||||
self._put(
|
||||
event_type=f"middleware:{tag}",
|
||||
category="middleware",
|
||||
content={"name": name, "hook": hook, "action": action, "changes": changes},
|
||||
)
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
del self._buffer[: self._flush_threshold]
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
return {
|
||||
"total_input_tokens": self._total_input_tokens,
|
||||
"total_output_tokens": self._total_output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"llm_call_count": self._llm_call_count,
|
||||
"lead_agent_tokens": self._lead_agent_tokens,
|
||||
"subagent_tokens": self._subagent_tokens,
|
||||
"middleware_tokens": self._middleware_tokens,
|
||||
"message_count": self._msg_count,
|
||||
"last_ai_message": self._last_ai_msg,
|
||||
"first_human_message": self._first_human_msg,
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Store provider for the DeerFlow runtime.
|
||||
|
||||
Re-exports the public API of both the async provider (for long-running
|
||||
servers) and the sync provider (for CLI tools and the embedded client).
|
||||
|
||||
Async usage (FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.store import make_store
|
||||
|
||||
async with make_store() as store:
|
||||
app.state.store = store
|
||||
|
||||
Sync usage (CLI / DeerFlowClient)::
|
||||
|
||||
from deerflow.runtime.store import get_store, store_context
|
||||
|
||||
store = get_store() # singleton
|
||||
with store_context() as store: ... # one-shot
|
||||
"""
|
||||
|
||||
from .async_provider import make_store
|
||||
from .provider import get_store, reset_store, store_context
|
||||
|
||||
__all__ = [
|
||||
# async
|
||||
"make_store",
|
||||
# sync
|
||||
"get_store",
|
||||
"reset_store",
|
||||
"store_context",
|
||||
]
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Shared SQLite connection utilities for store and checkpointer providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pathlib
|
||||
|
||||
from deerflow.config.paths import resolve_path
|
||||
|
||||
|
||||
def resolve_sqlite_conn_str(raw: str) -> str:
|
||||
"""Return a SQLite connection string ready for use with store/checkpointer backends.
|
||||
|
||||
SQLite special strings (``":memory:"`` and ``file:`` URIs) are returned
|
||||
unchanged. Plain filesystem paths — relative or absolute — are resolved
|
||||
to an absolute string via :func:`resolve_path`.
|
||||
"""
|
||||
if raw == ":memory:" or raw.startswith("file:"):
|
||||
return raw
|
||||
return str(resolve_path(raw))
|
||||
|
||||
|
||||
def ensure_sqlite_parent_dir(conn_str: str) -> None:
|
||||
"""Create parent directory for a SQLite filesystem path.
|
||||
|
||||
No-op for in-memory databases (``":memory:"``) and ``file:`` URIs.
|
||||
"""
|
||||
if conn_str != ":memory:" and not conn_str.startswith("file:"):
|
||||
pathlib.Path(conn_str).parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -1,113 +0,0 @@
|
||||
"""Async Store factory — backend mirrors the configured checkpointer.
|
||||
|
||||
The store and checkpointer share the same ``checkpointer`` section in
|
||||
*config.yaml* so they always use the same persistence backend:
|
||||
|
||||
- ``type: memory`` → :class:`langgraph.store.memory.InMemoryStore`
|
||||
- ``type: sqlite`` → :class:`langgraph.store.sqlite.aio.AsyncSqliteStore`
|
||||
- ``type: postgres`` → :class:`langgraph.store.postgres.aio.AsyncPostgresStore`
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.store import make_store
|
||||
|
||||
async with make_store() as store:
|
||||
app.state.store = store
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal backend factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_store(config) -> AsyncIterator[BaseStore]:
|
||||
"""Async context manager that constructs and tears down a Store.
|
||||
|
||||
The ``config`` argument is a :class:`deerflow.config.checkpointer_config.CheckpointerConfig`
|
||||
instance — the same object used by the checkpointer factory.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.info("Store: using InMemoryStore (in-process, not persistent)")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.store.sqlite.aio import AsyncSqliteStore
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_STORE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
|
||||
async with AsyncSqliteStore.from_conn_string(conn_str) as store:
|
||||
await store.setup()
|
||||
logger.info("Store: using AsyncSqliteStore (%s)", conn_str)
|
||||
yield store
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_STORE_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresStore.from_conn_string(config.connection_string) as store:
|
||||
await store.setup()
|
||||
logger.info("Store: using AsyncPostgresStore")
|
||||
yield store
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown store backend type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_store() -> AsyncIterator[BaseStore]:
|
||||
"""Async context manager that yields a Store whose backend matches the
|
||||
configured checkpointer.
|
||||
|
||||
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||
that both singletons always use the same persistence technology::
|
||||
|
||||
async with make_store() as store:
|
||||
app.state.store = store
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
``checkpointer`` section is configured (emits a WARNING in that case).
|
||||
"""
|
||||
config = get_app_config()
|
||||
|
||||
if config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
async with _async_store(config.checkpointer) as store:
|
||||
yield store
|
||||
@@ -1,188 +0,0 @@
|
||||
"""Sync Store factory.
|
||||
|
||||
Provides a **sync singleton** and a **sync context manager** for CLI tools
|
||||
and the embedded :class:`~deerflow.client.DeerFlowClient`.
|
||||
|
||||
The backend mirrors the configured checkpointer so that both always use the
|
||||
same persistence technology. Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.runtime.store.provider import get_store, store_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
store = get_store()
|
||||
|
||||
# One-shot — fresh connection, closed on block exit
|
||||
with store_context() as store:
|
||||
store.put(("ns",), "key", {"value": 1})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_store_cm(config) -> Iterator[BaseStore]:
|
||||
"""Context manager that creates and tears down a sync Store.
|
||||
|
||||
The ``config`` argument is a
|
||||
:class:`~deerflow.config.checkpointer_config.CheckpointerConfig` instance —
|
||||
the same object used by the checkpointer factory.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.info("Store: using InMemoryStore (in-process, not persistent)")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.store.sqlite import SqliteStore
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_STORE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
|
||||
with SqliteStore.from_conn_string(conn_str) as store:
|
||||
store.setup()
|
||||
logger.info("Store: using SqliteStore (%s)", conn_str)
|
||||
yield store
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.store.postgres import PostgresStore # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_STORE_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
with PostgresStore.from_conn_string(config.connection_string) as store:
|
||||
store.setup()
|
||||
logger.info("Store: using PostgresStore")
|
||||
yield store
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown store backend type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_store: BaseStore | None = None
|
||||
_store_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_store() -> BaseStore:
|
||||
"""Return the global sync Store singleton, creating it on first call.
|
||||
|
||||
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
checkpointer is configured in *config.yaml* (emits a WARNING in that case).
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
ValueError: If ``connection_string`` is missing for a backend that requires it.
|
||||
"""
|
||||
global _store, _store_ctx
|
||||
|
||||
if _store is not None:
|
||||
return _store
|
||||
|
||||
# Lazily load app config, mirroring the checkpointer singleton pattern so
|
||||
# that tests that set the global checkpointer config explicitly remain isolated.
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
_store = InMemoryStore()
|
||||
return _store
|
||||
|
||||
_store_ctx = _sync_store_cm(config)
|
||||
_store = _store_ctx.__enter__()
|
||||
return _store
|
||||
|
||||
|
||||
def reset_store() -> None:
|
||||
"""Reset the sync singleton, forcing recreation on the next call.
|
||||
|
||||
Closes any open backend connections and clears the cached instance.
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _store, _store_ctx
|
||||
if _store_ctx is not None:
|
||||
try:
|
||||
_store_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during store cleanup", exc_info=True)
|
||||
_store_ctx = None
|
||||
_store = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_context() -> Iterator[BaseStore]:
|
||||
"""Sync context manager that yields a Store and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_store`, this does **not** cache the instance — each
|
||||
``with`` block creates and destroys its own connection. Use it in CLI
|
||||
scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with store_context() as store:
|
||||
store.put(("threads",), thread_id, {...})
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
with _sync_store_cm(config.checkpointer) as store:
|
||||
yield store
|
||||
Reference in New Issue
Block a user