mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints
Phase 2-B: run persistence + event storage + token tracking. - ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow - RunRepository implements RunStore ABC via SQLAlchemy ORM - ThreadMetaRepository with owner access control - DbRunEventStore with trace content truncation and cursor pagination - JsonlRunEventStore with per-run files and seq recovery from disk - RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events, accumulates token usage by caller type, buffers and flushes to store - RunManager now accepts optional RunStore for persistent backing - Worker creates RunJournal, writes human_message, injects callbacks - Gateway deps use factory functions (RunRepository when DB available) - New endpoints: messages, run messages, run events, token-usage - ThreadCreateRequest gains assistant_id field - 92 tests pass (33 new), zero regressions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from deerflow.persistence.models.run import RunRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.models.thread_meta import ThreadMetaRow
|
||||
|
||||
__all__ = ["RunEventRow", "RunRow", "ThreadMetaRow"]
|
||||
@@ -0,0 +1,49 @@
|
||||
"""ORM model for run metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class RunRow(Base):
|
||||
__tablename__ = "runs"
|
||||
|
||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||
|
||||
model_name: Mapped[str | None] = mapped_column(String(128))
|
||||
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
error: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Convenience fields (for listing pages without querying RunEventStore)
|
||||
message_count: Mapped[int] = mapped_column(default=0)
|
||||
first_human_message: Mapped[str | None] = mapped_column(Text)
|
||||
last_ai_message: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Token usage (accumulated in-memory by RunJournal, written on run completion)
|
||||
total_input_tokens: Mapped[int] = mapped_column(default=0)
|
||||
total_output_tokens: Mapped[int] = mapped_column(default=0)
|
||||
total_tokens: Mapped[int] = mapped_column(default=0)
|
||||
llm_call_count: Mapped[int] = mapped_column(default=0)
|
||||
lead_agent_tokens: Mapped[int] = mapped_column(default=0)
|
||||
subagent_tokens: Mapped[int] = mapped_column(default=0)
|
||||
middleware_tokens: Mapped[int] = mapped_column(default=0)
|
||||
|
||||
# Follow-up association
|
||||
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
|
||||
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
||||
@@ -0,0 +1,30 @@
|
||||
"""ORM model for run events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class RunEventRow(Base):
|
||||
__tablename__ = "run_events"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
content: Mapped[str] = mapped_column(Text, default="")
|
||||
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
seq: Mapped[int] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
||||
Index("ix_events_run", "thread_id", "run_id", "seq"),
|
||||
)
|
||||
@@ -0,0 +1,23 @@
|
||||
"""ORM model for thread metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class ThreadMetaRow(Base):
|
||||
__tablename__ = "threads_meta"
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
@@ -0,0 +1,4 @@
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository
|
||||
|
||||
__all__ = ["RunRepository", "ThreadMetaRepository"]
|
||||
@@ -0,0 +1,174 @@
|
||||
"""SQLAlchemy-backed RunStore implementation.
|
||||
|
||||
Each method acquires and releases its own short-lived session.
|
||||
Run status updates happen from background workers that may live
|
||||
minutes -- we don't hold connections across long execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _safe_json(obj: Any) -> Any:
|
||||
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: RunRepository._safe_json(v) for k, v in obj.items()}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [RunRepository._safe_json(v) for v in obj]
|
||||
if hasattr(obj, "model_dump"):
|
||||
try:
|
||||
return obj.model_dump()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(obj, "dict"):
|
||||
try:
|
||||
return obj.dict()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
json.dumps(obj)
|
||||
return obj
|
||||
except (TypeError, ValueError):
|
||||
return str(obj)
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: RunRow) -> dict[str, Any]:
|
||||
d = row.to_dict()
|
||||
# Remap JSON columns to match RunStore interface
|
||||
d["metadata"] = d.pop("metadata_json", {})
|
||||
d["kwargs"] = d.pop("kwargs_json", {})
|
||||
# Convert datetime to ISO string for consistency with MemoryRunStore
|
||||
for key in ("created_at", "updated_at"):
|
||||
val = d.get(key)
|
||||
if isinstance(val, datetime):
|
||||
d[key] = val.isoformat()
|
||||
return d
|
||||
|
||||
async def put(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
owner_id=None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
kwargs=None,
|
||||
error=None,
|
||||
created_at=None,
|
||||
):
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=owner_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
kwargs_json=self._safe_json(kwargs) or {},
|
||||
error=error,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def get(self, run_id):
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
|
||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if owner_id is not None:
|
||||
stmt = stmt.where(RunRow.owner_id == owner_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_status(self, run_id, status, *, error=None):
|
||||
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, run_id):
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is not None:
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
now = before or datetime.now(UTC).isoformat()
|
||||
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= now).order_by(RunRow.created_at.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
total_input_tokens: int = 0,
|
||||
total_output_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
llm_call_count: int = 0,
|
||||
lead_agent_tokens: int = 0,
|
||||
subagent_tokens: int = 0,
|
||||
middleware_tokens: int = 0,
|
||||
message_count: int = 0,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Update status + token usage + convenience fields on run completion."""
|
||||
values: dict[str, Any] = {
|
||||
"status": status,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"llm_call_count": llm_call_count,
|
||||
"lead_agent_tokens": lead_agent_tokens,
|
||||
"subagent_tokens": subagent_tokens,
|
||||
"middleware_tokens": middleware_tokens,
|
||||
"message_count": message_count,
|
||||
"updated_at": datetime.now(UTC),
|
||||
}
|
||||
if last_ai_message is not None:
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if first_human_message is not None:
|
||||
values["first_human_message"] = first_human_message[:2000]
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
@@ -0,0 +1,91 @@
|
||||
"""SQLAlchemy-backed thread metadata repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.thread_meta import ThreadMetaRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadMetaRepository:
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("metadata_json", {})
|
||||
for key in ("created_at", "updated_at"):
|
||||
val = d.get(key)
|
||||
if isinstance(val, datetime):
|
||||
d[key] = val.isoformat()
|
||||
return d
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
owner_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
owner_id=owner_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(self, thread_id: str) -> dict | None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
|
||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def check_access(self, thread_id: str, owner_id: str) -> bool:
|
||||
"""Check if owner_id has access to thread_id.
|
||||
|
||||
Returns True if: row doesn't exist (untracked thread), owner_id
|
||||
is None on the row (shared thread), or owner_id matches.
|
||||
"""
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return True
|
||||
if row.owner_id is None:
|
||||
return True
|
||||
return row.owner_id == owner_id
|
||||
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def delete(self, thread_id: str) -> None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is not None:
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
@@ -1,4 +1,26 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore"]
|
||||
|
||||
def make_run_event_store(config=None) -> RunEventStore:
|
||||
"""Create a RunEventStore based on run_events.backend configuration."""
|
||||
if config is None or config.backend == "memory":
|
||||
return MemoryRunEventStore()
|
||||
if config.backend == "db":
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
# database.backend=memory but run_events.backend=db -> fallback
|
||||
return MemoryRunEventStore()
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
|
||||
if config.backend == "jsonl":
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
return JsonlRunEventStore()
|
||||
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
|
||||
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
"""SQLAlchemy-backed RunEventStore implementation.
|
||||
|
||||
Persists events to the ``run_events`` table. Trace content is truncated
|
||||
at ``max_trace_content`` bytes to avoid bloating the database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
|
||||
class DbRunEventStore(RunEventStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
|
||||
self._sf = session_factory
|
||||
self._max_trace_content = max_trace_content
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: RunEventRow) -> dict:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("event_metadata", {})
|
||||
val = d.get("created_at")
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
d.pop("id", None)
|
||||
return d
|
||||
|
||||
def _truncate_trace(self, category: str, content: str, metadata: dict | None) -> tuple[str, dict]:
|
||||
if category == "trace" and len(content) > self._max_trace_content:
|
||||
content = content[: self._max_trace_content]
|
||||
metadata = {**(metadata or {}), "content_truncated": True}
|
||||
return content, metadata or {}
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
async with self._sf() as session:
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
async with self._sf() as session:
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread)
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
seq += 1
|
||||
content = e.get("content", "")
|
||||
category = e.get("category", "trace")
|
||||
metadata = e.get("metadata")
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
rows.append(row)
|
||||
await session.commit()
|
||||
for row in rows:
|
||||
await session.refresh(row)
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
# Forward pagination: first `limit` records after cursor
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
# before_seq or default (latest): take last `limit` records, return ascending
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
async with self._sf() as session:
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
async with self._sf() as session:
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
|
||||
await session.commit()
|
||||
return count
|
||||
@@ -0,0 +1,164 @@
|
||||
"""JSONL file-backed RunEventStore implementation.
|
||||
|
||||
Each run's events are stored in a single file:
|
||||
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
|
||||
|
||||
All categories (message, trace, lifecycle) are in the same file.
|
||||
This backend is suitable for lightweight single-node deployments.
|
||||
|
||||
Known trade-off: ``list_messages()`` must scan all run files for a
|
||||
thread since messages from multiple runs need unified seq ordering.
|
||||
``list_events()`` reads only one file -- the fast path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonlRunEventStore(RunEventStore):
|
||||
def __init__(self, base_dir: str | Path | None = None):
|
||||
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
|
||||
|
||||
def _thread_dir(self, thread_id: str) -> Path:
|
||||
return self._base_dir / "threads" / thread_id / "runs"
|
||||
|
||||
def _run_file(self, thread_id: str, run_id: str) -> Path:
|
||||
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
|
||||
|
||||
def _next_seq(self, thread_id: str) -> int:
|
||||
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
|
||||
return self._seq_counters[thread_id]
|
||||
|
||||
def _ensure_seq_loaded(self, thread_id: str) -> None:
|
||||
"""Load max seq from existing files if not yet cached."""
|
||||
if thread_id in self._seq_counters:
|
||||
return
|
||||
max_seq = 0
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
for line in f.read_text(encoding="utf-8").strip().splitlines():
|
||||
try:
|
||||
record = json.loads(line)
|
||||
max_seq = max(max_seq, record.get("seq", 0))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
self._seq_counters[thread_id] = max_seq
|
||||
|
||||
def _write_record(self, record: dict) -> None:
|
||||
path = self._run_file(record["thread_id"], record["run_id"])
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
|
||||
|
||||
def _read_thread_events(self, thread_id: str) -> list[dict]:
|
||||
"""Read all events for a thread, sorted by seq."""
|
||||
events = []
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if not thread_dir.exists():
|
||||
return events
|
||||
for f in sorted(thread_dir.glob("*.jsonl")):
|
||||
for line in f.read_text(encoding="utf-8").strip().splitlines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
|
||||
"""Read events for a specific run file."""
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
events = []
|
||||
for line in path.read_text(encoding="utf-8").strip().splitlines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
|
||||
self._ensure_seq_loaded(thread_id)
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._write_record(record)
|
||||
return record
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
results = []
|
||||
for ev in events:
|
||||
record = await self.put(**ev)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
messages = [e for e in all_events if e.get("category") == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] < before_seq]
|
||||
return messages[-limit:]
|
||||
elif after_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] > after_seq]
|
||||
return messages[:limit]
|
||||
else:
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
if event_types is not None:
|
||||
events = [e for e in events if e.get("event_type") in event_types]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
return [e for e in events if e.get("category") == "message"]
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
return sum(1 for e in all_events if e.get("category") == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
count = len(all_events)
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
f.unlink()
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
count = len(events)
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return count
|
||||
@@ -0,0 +1,333 @@
|
||||
"""Run event capture via LangChain callbacks.
|
||||
|
||||
RunJournal sits between LangChain's callback mechanism and the pluggable
|
||||
RunEventStore. It standardizes callback data into RunEvent records and
|
||||
handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- All LangChain objects serialized via serialize_lc_object (same as worker.py SSE)
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunJournal(BaseCallbackHandler):
|
||||
"""LangChain callback handler that captures events to RunEventStore."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
event_store: RunEventStore,
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
on_complete: Callable[..., Any] | None = None,
|
||||
flush_threshold: int = 20,
|
||||
):
|
||||
super().__init__()
|
||||
self.run_id = run_id
|
||||
self.thread_id = thread_id
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._on_complete = on_complete
|
||||
self._flush_threshold = flush_threshold
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._llm_call_count = 0
|
||||
self._lead_agent_tokens = 0
|
||||
self._subagent_tokens = 0
|
||||
self._middleware_tokens = 0
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
self._first_human_msg: str | None = None
|
||||
self._msg_count = 0
|
||||
|
||||
# Latency tracking
|
||||
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
# Only record for the top-level chain (parent_run_id is None)
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_start",
|
||||
category="lifecycle",
|
||||
metadata={"input_preview": str(inputs)[:500]},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
if self._on_complete:
|
||||
self._on_complete(
|
||||
total_input_tokens=self._total_input_tokens,
|
||||
total_output_tokens=self._total_output_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
llm_call_count=self._llm_call_count,
|
||||
lead_agent_tokens=self._lead_agent_tokens,
|
||||
subagent_tokens=self._subagent_tokens,
|
||||
middleware_tokens=self._middleware_tokens,
|
||||
message_count=self._msg_count,
|
||||
last_ai_message=self._last_ai_msg,
|
||||
first_human_message=self._first_human_msg,
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
event_type="run_error",
|
||||
category="lifecycle",
|
||||
content=str(error),
|
||||
metadata={"error_type": type(error).__name__},
|
||||
)
|
||||
self._flush_sync()
|
||||
|
||||
# -- LLM callbacks --
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||
self._put(
|
||||
event_type="llm_start",
|
||||
category="trace",
|
||||
metadata={"model_name": serialized.get("name", "")},
|
||||
)
|
||||
|
||||
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
try:
|
||||
message = response.generations[0][0].message
|
||||
except (IndexError, AttributeError):
|
||||
logger.debug("on_llm_end: could not extract message from response")
|
||||
return
|
||||
|
||||
serialized_msg = serialize_lc_object(message)
|
||||
caller = self._identify_caller(kwargs)
|
||||
|
||||
# Latency
|
||||
start = self._llm_start_times.pop(str(run_id), None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# trace event: llm_end (every LLM call)
|
||||
self._put(
|
||||
event_type="llm_end",
|
||||
category="trace",
|
||||
content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")),
|
||||
metadata={
|
||||
"message": serialized_msg,
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
},
|
||||
)
|
||||
|
||||
# message event: ai_message (only lead_agent final replies with content)
|
||||
if caller == "lead_agent":
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str) and content:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content=content,
|
||||
metadata={
|
||||
"model_name": model_name,
|
||||
"tool_calls": tool_calls_summary,
|
||||
},
|
||||
)
|
||||
self._last_ai_msg = content[:2000]
|
||||
self._msg_count += 1
|
||||
|
||||
# Token accumulation
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if self._track_tokens and total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm_error", category="trace", content=str(error))
|
||||
|
||||
# -- Tool callbacks --
|
||||
|
||||
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(
|
||||
event_type="tool_start",
|
||||
category="trace",
|
||||
metadata={
|
||||
"tool_name": serialized.get("name", ""),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"args": str(input_str)[:2000],
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(
|
||||
event_type="tool_end",
|
||||
category="trace",
|
||||
content=str(output),
|
||||
metadata={
|
||||
"tool_name": kwargs.get("name", ""),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"status": "success",
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(
|
||||
event_type="tool_error",
|
||||
category="trace",
|
||||
content=str(error),
|
||||
metadata={
|
||||
"tool_name": kwargs.get("name", ""),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
},
|
||||
)
|
||||
|
||||
# -- Custom event callback --
|
||||
|
||||
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
|
||||
if name == "summarization":
|
||||
data_dict = data if isinstance(data, dict) else {}
|
||||
self._put(
|
||||
event_type="summarization",
|
||||
category="trace",
|
||||
content=data_dict.get("summary", ""),
|
||||
metadata={
|
||||
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
|
||||
"replaced_count": data_dict.get("replaced_count", 0),
|
||||
},
|
||||
)
|
||||
self._put(
|
||||
event_type="summary",
|
||||
category="message",
|
||||
content=data_dict.get("summary", ""),
|
||||
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
||||
)
|
||||
else:
|
||||
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
|
||||
self._put(
|
||||
event_type=name,
|
||||
category="trace",
|
||||
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
|
||||
)
|
||||
|
||||
# -- Internal methods --
|
||||
|
||||
def _put(self, *, event_type: str, category: str, content: str = "", metadata: dict | None = None) -> None:
|
||||
self._buffer.append({
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
})
|
||||
if len(self._buffer) >= self._flush_threshold:
|
||||
self._flush_sync()
|
||||
|
||||
def _flush_sync(self) -> None:
|
||||
"""Flush buffer to RunEventStore.
|
||||
|
||||
BaseCallbackHandler methods are synchronous. We schedule the async
|
||||
put_batch via the current event loop.
|
||||
"""
|
||||
if not self._buffer:
|
||||
return
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._flush_async(batch))
|
||||
except RuntimeError:
|
||||
logger.warning("RunJournal: no event loop, dropping %d events", len(batch))
|
||||
|
||||
async def _flush_async(self, batch: list[dict]) -> None:
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True)
|
||||
|
||||
def _identify_caller(self, kwargs: dict) -> str:
|
||||
for tag in kwargs.get("tags") or []:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
return "unknown"
|
||||
|
||||
# -- Public methods (called by worker) --
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Force flush. Used in cancel/error paths."""
|
||||
if self._buffer:
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
await self._store.put_batch(batch)
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
return {
|
||||
"total_input_tokens": self._total_input_tokens,
|
||||
"total_output_tokens": self._total_output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"llm_call_count": self._llm_call_count,
|
||||
"lead_agent_tokens": self._lead_agent_tokens,
|
||||
"subagent_tokens": self._subagent_tokens,
|
||||
"middleware_tokens": self._middleware_tokens,
|
||||
"message_count": self._msg_count,
|
||||
"last_ai_message": self._last_ai_msg,
|
||||
"first_human_message": self._first_human_msg,
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
"""In-memory run registry."""
|
||||
"""In-memory run registry with optional persistent RunStore backing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -7,9 +7,13 @@ import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -38,11 +42,17 @@ class RunRecord:
|
||||
|
||||
|
||||
class RunManager:
|
||||
"""In-memory run registry. All mutations are protected by an asyncio lock."""
|
||||
"""In-memory run registry with optional persistent RunStore backing.
|
||||
|
||||
def __init__(self) -> None:
|
||||
All mutations are protected by an asyncio lock. When a ``store`` is
|
||||
provided, serializable metadata is also persisted to the store so
|
||||
that run history survives process restarts.
|
||||
"""
|
||||
|
||||
def __init__(self, store: RunStore | None = None) -> None:
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
|
||||
async def create(
|
||||
self,
|
||||
@@ -71,6 +81,20 @@ class RunManager:
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.put(
|
||||
run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending.value,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
@@ -96,6 +120,11 @@ class RunManager:
|
||||
record.updated_at = _now_iso()
|
||||
if error is not None:
|
||||
record.error = error
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_status(run_id, status.value, error=error)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
|
||||
logger.info("Run %s -> %s", run_id, status.value)
|
||||
|
||||
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
|
||||
@@ -185,6 +214,21 @@ class RunManager:
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.put(
|
||||
run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending.value,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
|
||||
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
|
||||
@@ -45,6 +45,8 @@ async def run_agent(
|
||||
stream_subgraphs: bool = False,
|
||||
interrupt_before: list[str] | Literal["*"] | None = None,
|
||||
interrupt_after: list[str] | Literal["*"] | None = None,
|
||||
event_store: Any | None = None,
|
||||
run_events_config: Any | None = None,
|
||||
) -> None:
|
||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||
|
||||
@@ -52,6 +54,30 @@ async def run_agent(
|
||||
thread_id = record.thread_id
|
||||
requested_modes: set[str] = set(stream_modes or ["values"])
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event
|
||||
user_input = _extract_user_input(graph_input)
|
||||
if user_input:
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=user_input,
|
||||
)
|
||||
journal.set_first_human_message(user_input)
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
@@ -92,6 +118,10 @@ async def run_agent(
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a callback
|
||||
if journal is not None:
|
||||
config.setdefault("callbacks", []).append(journal)
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
@@ -206,6 +236,13 @@ async def run_agent(
|
||||
)
|
||||
|
||||
finally:
|
||||
# Flush any buffered journal events
|
||||
if journal is not None:
|
||||
try:
|
||||
await journal.flush()
|
||||
except Exception:
|
||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
@@ -227,6 +264,23 @@ def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
return mode
|
||||
|
||||
|
||||
def _extract_user_input(graph_input: dict) -> str:
|
||||
"""Extract user input text from graph_input for event recording."""
|
||||
messages = graph_input.get("messages")
|
||||
if not messages:
|
||||
return ""
|
||||
# Take the last message (usually the user's input)
|
||||
last = messages[-1] if isinstance(messages, list) else messages
|
||||
if isinstance(last, str):
|
||||
return last
|
||||
if hasattr(last, "content"):
|
||||
content = last.content
|
||||
return content if isinstance(content, str) else str(content)
|
||||
if isinstance(last, dict):
|
||||
return str(last.get("content", ""))
|
||||
return ""
|
||||
|
||||
|
||||
def _unpack_stream_item(
|
||||
item: Any,
|
||||
lg_modes: list[str],
|
||||
|
||||
Reference in New Issue
Block a user