mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints
Phase 2-B: run persistence + event storage + token tracking.

- ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow
- RunRepository implements RunStore ABC via SQLAlchemy ORM
- ThreadMetaRepository with owner access control
- DbRunEventStore with trace content truncation and cursor pagination
- JsonlRunEventStore with per-run files and seq recovery from disk
- RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events, accumulates token usage by caller type, buffers and flushes to store
- RunManager now accepts optional RunStore for persistent backing
- Worker creates RunJournal, writes human_message, injects callbacks
- Gateway deps use factory functions (RunRepository when DB available)
- New endpoints: messages, run messages, run events, token-usage
- ThreadCreateRequest gains assistant_id field
- 92 tests pass (33 new), zero regressions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+32
-10
@@ -31,27 +31,23 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
app.state.run_manager = RunManager()
|
||||
|
||||
# Initialize persistence layer from unified database config
|
||||
config = get_app_config()
|
||||
await init_engine_from_config(config.database)
|
||||
|
||||
# Initialize run store (MemoryRunStore for now; switch to ORM-backed
|
||||
# RunRepository when models are implemented)
|
||||
app.state.run_store = MemoryRunStore()
|
||||
# Initialize run store (RunRepository if DB available, else MemoryRunStore)
|
||||
app.state.run_store = _make_run_store()
|
||||
|
||||
# Initialize run event store (MemoryRunEventStore for now)
|
||||
# TODO(Phase 2-B): switch to db/jsonl backend based on config.run_events.backend
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
# Initialize run event store based on config
|
||||
app.state.run_event_store = _make_run_event_store(config)
|
||||
|
||||
app.state.run_event_store = MemoryRunEventStore()
|
||||
# RunManager with store backing for persistence
|
||||
app.state.run_manager = RunManager(store=app.state.run_store)
|
||||
|
||||
try:
|
||||
yield
|
||||
@@ -59,6 +55,32 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await close_engine()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_run_store() -> RunStore:
    """Build the run store used by the gateway.

    Returns a ``RunRepository`` bound to the session factory when
    ``get_session_factory()`` yields one, otherwise a ``MemoryRunStore``.
    Each backend module is imported only on the branch that uses it.
    """
    from deerflow.persistence.engine import get_session_factory

    session_factory = get_session_factory()
    if session_factory is None:
        # No session factory available: fall back to the in-memory store.
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        return MemoryRunStore()

    from deerflow.persistence.repositories.run_repo import RunRepository

    return RunRepository(session_factory)
|
||||
|
||||
|
||||
def _make_run_event_store(config) -> RunEventStore:
    """Build the run event store from the application config.

    Passes ``config.run_events`` to ``make_run_event_store``; when *config*
    has no ``run_events`` attribute, ``None`` is passed instead.
    """
    from deerflow.runtime.events.store import make_run_event_store

    return make_run_event_store(getattr(config, "run_events", None))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters -- called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
|
||||
@@ -263,3 +263,77 @@ async def stream_existing_run(
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Messages / Events / Token usage endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages")
async def list_thread_messages(
    thread_id: str,
    request: Request,
    # ge=1: reject zero/negative limits at the API boundary instead of
    # forwarding them to the store. NOTE(review): how each store backend
    # treats limit <= 0 is not visible here -- confirm none relied on it.
    limit: int = Query(default=50, ge=1, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs).

    Args:
        thread_id: Thread whose messages are listed.
        request: Incoming request, used to look up the run event store.
        limit: Maximum number of messages to return (1-200, default 50).
        before_seq: Forwarded unchanged to the store; presumably an upper
            sequence-number cursor -- verify against the store implementation.
        after_seq: Forwarded unchanged to the store; presumably a lower
            sequence-number cursor -- verify against the store implementation.

    Returns:
        Whatever ``event_store.list_messages`` returns for these arguments.
    """
    event_store = get_run_event_store(request)
    return await event_store.list_messages(
        thread_id,
        limit=limit,
        before_seq=before_seq,
        after_seq=after_seq,
    )
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages")
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
    """Return displayable messages for a specific run.

    Delegates to ``list_messages_by_run`` on the run event store obtained
    from *request*, passing *thread_id* and *run_id* through unchanged.
    """
    return await get_run_event_store(request).list_messages_by_run(thread_id, run_id)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/events")
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    # ge=1: reject zero/negative limits at the API boundary instead of
    # forwarding them to the store. NOTE(review): how each store backend
    # treats limit <= 0 is not visible here -- confirm none relied on it.
    limit: int = Query(default=500, ge=1, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit).

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are listed.
        request: Incoming request, used to look up the run event store.
        event_types: Optional comma-separated list of event type names.
            Whitespace around each name is ignored and empty entries are
            dropped; if nothing remains, no type filter is applied.
        limit: Maximum number of events to return (1-2000, default 500).

    Returns:
        Whatever ``event_store.list_events`` returns for these arguments.
    """
    event_store = get_run_event_store(request)

    types = None
    if event_types:
        # "a, b," -> ["a", "b"]; a value made only of commas/blanks is
        # treated like an omitted parameter rather than a filter on "".
        types = [name.strip() for name in event_types.split(",") if name.strip()] or None

    return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
||||
|
||||
|
||||
def _token_count(run: dict, field: str) -> int:
    """Return ``run[field]``, treating a missing or ``None`` value as 0."""
    # ``run.get(field, 0)`` would still return None when the key is present
    # with a None value, which breaks the additions below.
    return run.get(field) or 0


@router.get("/{thread_id}/token-usage")
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Thread-level token usage aggregation.

    Loads up to 10000 runs of the thread from the run store, keeps those
    whose ``status`` is ``"success"`` or ``"error"``, and sums their token
    counters. Counters that are missing or ``None`` count as 0.

    Args:
        thread_id: Thread whose runs are aggregated.
        request: Incoming request, used to look up the run store.

    Returns:
        A dict with ``thread_id``, the three token totals, ``total_runs``
        (number of runs that were aggregated), ``by_model`` (tokens and run
        count per ``model_name``, ``"unknown"`` when absent or empty) and
        ``by_caller`` (lead agent / subagent / middleware token totals).
    """
    run_store = get_run_store(request)
    # NOTE(review): assuming the store honors ``limit``, runs beyond the
    # first 10000 are not included in the totals.
    runs = await run_store.list_by_thread(thread_id, limit=10000)
    completed = [r for r in runs if r.get("status") in ("success", "error")]

    total_tokens = 0
    total_input = 0
    total_output = 0
    by_model: dict[str, dict] = {}
    by_caller = {"lead_agent": 0, "subagent": 0, "middleware": 0}

    # Single pass over the finished runs instead of one traversal per counter.
    for run in completed:
        run_tokens = _token_count(run, "total_tokens")
        total_tokens += run_tokens
        total_input += _token_count(run, "total_input_tokens")
        total_output += _token_count(run, "total_output_tokens")

        model = run.get("model_name") or "unknown"
        entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
        entry["tokens"] += run_tokens
        entry["runs"] += 1

        by_caller["lead_agent"] += _token_count(run, "lead_agent_tokens")
        by_caller["subagent"] += _token_count(run, "subagent_tokens")
        by_caller["middleware"] += _token_count(run, "middleware_tokens")

    return {
        "thread_id": thread_id,
        "total_tokens": total_tokens,
        "total_input_tokens": total_input,
        "total_output_tokens": total_output,
        "total_runs": len(completed),
        "by_model": by_model,
        "by_caller": by_caller,
    }
|
||||
|
||||
@@ -63,6 +63,7 @@ class ThreadCreateRequest(BaseModel):
|
||||
"""Request body for creating a thread."""
|
||||
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_store, get_stream_bridge
|
||||
from deerflow.runtime import (
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
@@ -245,6 +245,12 @@ async def start_run(
|
||||
run_mgr = get_run_manager(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
event_store = get_run_event_store(request)
|
||||
|
||||
# Get run_events config for journal
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
run_events_config = getattr(get_app_config(), "run_events", None)
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
@@ -287,6 +293,8 @@ async def start_run(
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
event_store=event_store,
|
||||
run_events_config=run_events_config,
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
|
||||
Reference in New Issue
Block a user