feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints

Phase 2-B: run persistence + event storage + token tracking.

- ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow
- RunRepository implements RunStore ABC via SQLAlchemy ORM
- ThreadMetaRepository with owner access control
- DbRunEventStore with trace content truncation and cursor pagination
- JsonlRunEventStore with per-run files and seq recovery from disk
- RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events,
  accumulates token usage by caller type, buffers and flushes to store
- RunManager now accepts optional RunStore for persistent backing
- Worker creates RunJournal, writes human_message, injects callbacks
- Gateway deps use factory functions (RunRepository when DB available)
- New endpoints: messages, run messages, run events, token-usage
- ThreadCreateRequest gains assistant_id field
- 92 tests pass (33 new), zero regressions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-02 19:03:38 +08:00
parent 23eacf9533
commit e3179cd54d
21 changed files with 1946 additions and 29 deletions
+32 -10
View File
@@ -31,27 +31,23 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.runs.store.memory import MemoryRunStore
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
# Initialize persistence layer from unified database config
config = get_app_config()
await init_engine_from_config(config.database)
# Initialize run store (MemoryRunStore for now; switch to ORM-backed
# RunRepository when models are implemented)
app.state.run_store = MemoryRunStore()
# Initialize run store (RunRepository if DB available, else MemoryRunStore)
app.state.run_store = _make_run_store()
# Initialize run event store (MemoryRunEventStore for now)
# TODO(Phase 2-B): switch to db/jsonl backend based on config.run_events.backend
from deerflow.runtime.events.store.memory import MemoryRunEventStore
# Initialize run event store based on config
app.state.run_event_store = _make_run_event_store(config)
app.state.run_event_store = MemoryRunEventStore()
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
try:
yield
@@ -59,6 +55,32 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
await close_engine()
# ---------------------------------------------------------------------------
# Factories
# ---------------------------------------------------------------------------
def _make_run_store() -> RunStore:
    """Build the run store used by the gateway.

    Returns a ``RunRepository`` constructed with the session factory when
    ``get_session_factory()`` yields one, and a ``MemoryRunStore`` otherwise.
    Imports are kept local so that only the chosen backend is loaded.
    """
    from deerflow.persistence.engine import get_session_factory

    session_factory = get_session_factory()
    if session_factory is None:
        # No session factory available: fall back to the in-memory store.
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        return MemoryRunStore()

    from deerflow.persistence.repositories.run_repo import RunRepository

    return RunRepository(session_factory)
def _make_run_event_store(config) -> RunEventStore:
    """Create the run event store from the application config.

    Reads ``config.run_events`` (``None`` when the attribute is absent) and
    hands it to ``make_run_event_store``, which selects the backend.
    """
    from deerflow.runtime.events.store import make_run_event_store

    return make_run_event_store(getattr(config, "run_events", None))
# ---------------------------------------------------------------------------
# Getters -- called by routers per-request
# ---------------------------------------------------------------------------
+75 -1
View File
@@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@@ -263,3 +263,77 @@ async def stream_existing_run(
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
async def list_thread_messages(
    thread_id: str,
    request: Request,
    limit: int = Query(default=50, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """List the displayable messages of a thread, across all of its runs.

    ``limit`` defaults to 50 and is capped at 200 by the query validator.
    ``before_seq`` and ``after_seq`` are forwarded unchanged to the run event
    store's ``list_messages``; their exact cursor semantics are defined there.
    """
    store = get_run_event_store(request)
    messages = await store.list_messages(
        thread_id,
        limit=limit,
        before_seq=before_seq,
        after_seq=after_seq,
    )
    return messages
@router.get("/{thread_id}/runs/{run_id}/messages")
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
    """List the displayable messages that belong to one run of a thread.

    Delegates to the run event store's ``list_messages_by_run`` with the
    thread and run identifiers taken from the URL path.
    """
    store = get_run_event_store(request)
    run_messages = await store.list_messages_by_run(thread_id, run_id)
    return run_messages
@router.get("/{thread_id}/runs/{run_id}/events")
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    limit: int = Query(default=500, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit).

    Args:
        thread_id: Thread identifier from the URL path.
        run_id: Run identifier from the URL path.
        request: Incoming request, used to look up the run event store.
        event_types: Optional comma-separated list of event type names to
            filter by. Whitespace around each name is ignored and empty
            entries (e.g. from a trailing comma) are dropped. When nothing
            remains, no filter is applied.
        limit: Maximum number of events to return (default 500, capped at
            2000 by the query validator).

    Returns:
        The list of event dicts produced by the store's ``list_events``.
    """
    event_store = get_run_event_store(request)

    types: list[str] | None = None
    if event_types:
        # Normalise "a, b," -> ["a", "b"] so stray spaces or commas do not
        # become bogus filter values that can never match a real event type.
        cleaned = [part.strip() for part in event_types.split(",") if part.strip()]
        # A value made only of separators/blanks is treated like "no filter",
        # consistent with how an empty string is already handled.
        types = cleaned or None

    return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Aggregate token usage over the finished runs of a thread.

    Only runs whose ``status`` is ``"success"`` or ``"error"`` are counted.
    Token fields that are missing or explicitly ``None`` count as zero, so a
    run record without usage data cannot break the aggregation.

    Args:
        thread_id: Thread identifier from the URL path.
        request: Incoming request, used to look up the run store.

    Returns:
        A dict with ``thread_id``, the overall ``total_tokens`` /
        ``total_input_tokens`` / ``total_output_tokens``, ``total_runs``
        (number of counted runs), ``by_model`` (per-model ``tokens`` and
        ``runs``; runs without a model name are grouped under ``"unknown"``)
        and ``by_caller`` (``lead_agent`` / ``subagent`` / ``middleware``).
    """
    run_store = get_run_store(request)
    # At most 10000 runs are requested from the store; how the store applies
    # this limit is defined by its ``list_by_thread`` implementation.
    runs = await run_store.list_by_thread(thread_id, limit=10000)

    def _tokens(run: dict, field: str) -> int:
        # ``dict.get(field, 0)`` still yields None when the key is present
        # with a None value; ``or 0`` covers both missing and None.
        return run.get(field) or 0

    total_tokens = 0
    total_input = 0
    total_output = 0
    total_runs = 0
    by_model: dict[str, dict] = {}
    by_caller = {"lead_agent": 0, "subagent": 0, "middleware": 0}

    # Single pass over the runs instead of one pass per aggregate.
    for run in runs:
        if run.get("status") not in ("success", "error"):
            continue

        total_runs += 1
        run_total = _tokens(run, "total_tokens")
        total_tokens += run_total
        total_input += _tokens(run, "total_input_tokens")
        total_output += _tokens(run, "total_output_tokens")

        model = run.get("model_name") or "unknown"
        entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
        entry["tokens"] += run_total
        entry["runs"] += 1

        by_caller["lead_agent"] += _tokens(run, "lead_agent_tokens")
        by_caller["subagent"] += _tokens(run, "subagent_tokens")
        by_caller["middleware"] += _tokens(run, "middleware_tokens")

    return {
        "thread_id": thread_id,
        "total_tokens": total_tokens,
        "total_input_tokens": total_input,
        "total_output_tokens": total_output,
        "total_runs": total_runs,
        "by_model": by_model,
        "by_caller": by_caller,
    }
+1
View File
@@ -63,6 +63,7 @@ class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
+9 -1
View File
@@ -17,7 +17,7 @@ from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_store, get_stream_bridge
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
@@ -245,6 +245,12 @@ async def start_run(
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
event_store = get_run_event_store(request)
# Get run_events config for journal
from deerflow.config import get_app_config
run_events_config = getattr(get_app_config(), "run_events", None)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@@ -287,6 +293,8 @@ async def start_run(
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
event_store=event_store,
run_events_config=run_events_config,
)
)
record.task = task