feat(persistence): add RunEventStore ABC + MemoryRunEventStore

Phase 2-A prerequisite for event storage: adds the unified run event
stream interface (RunEventStore) with an in-memory implementation,
RunEventsConfig, gateway integration, and comprehensive tests (27 cases).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-02 14:23:13 +08:00
parent 1ff6b5f7ab
commit 23eacf9533
9 changed files with 563 additions and 0 deletions
+15
View File
@@ -14,6 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager, StreamBridge from deerflow.runtime import RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore from deerflow.runtime.runs.store.base import RunStore
@@ -46,6 +47,12 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# RunRepository when models are implemented) # RunRepository when models are implemented)
app.state.run_store = MemoryRunStore() app.state.run_store = MemoryRunStore()
# Initialize run event store (MemoryRunEventStore for now)
# TODO(Phase 2-B): switch to db/jsonl backend based on config.run_events.backend
from deerflow.runtime.events.store.memory import MemoryRunEventStore
app.state.run_event_store = MemoryRunEventStore()
try: try:
yield yield
finally: finally:
@@ -86,6 +93,14 @@ def get_store(request: Request):
return getattr(request.app.state, "store", None) return getattr(request.app.state, "store", None)
def get_run_event_store(request: Request) -> RunEventStore:
    """Look up the run event store attached to the application state.

    Reads ``request.app.state.run_event_store``. If the attribute is missing
    or ``None``, an ``HTTPException`` with status 503 is raised instead.
    """
    event_store = getattr(request.app.state, "run_event_store", None)
    if event_store is not None:
        return event_store
    raise HTTPException(status_code=503, detail="Run event store not available")
def get_run_store(request: Request) -> RunStore: def get_run_store(request: Request) -> RunStore:
"""Return the RunStore, or 503 if not available.""" """Return the RunStore, or 503 if not available."""
store = getattr(request.app.state, "run_store", None) store = getattr(request.app.state, "run_store", None)
@@ -14,6 +14,7 @@ from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import load_guardrails_config_from_dict from deerflow.config.guardrails_config import load_guardrails_config_from_dict
from deerflow.config.memory_config import load_memory_config_from_dict from deerflow.config.memory_config import load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skills_config import SkillsConfig from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
@@ -43,6 +44,7 @@ class AppConfig(BaseModel):
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration") tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
model_config = ConfigDict(extra="allow", frozen=False) model_config = ConfigDict(extra="allow", frozen=False)
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration") checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration") stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@@ -0,0 +1,33 @@
"""Run event storage configuration.
Controls where run events (messages + execution traces) are persisted.
Backends:
- memory: In-memory storage, data lost on restart. Suitable for
development and testing.
- db: SQL database via SQLAlchemy ORM. Provides full query capability.
Suitable for production deployments.
- jsonl: Append-only JSONL files. Lightweight alternative for
single-node deployments that need persistence without a database.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class RunEventsConfig(BaseModel):
    """Settings for the run event storage layer.

    All fields have defaults, so an empty configuration selects the
    in-memory backend.
    """

    # Which storage implementation to use; restricted to the three literals.
    backend: Literal["memory", "db", "jsonl"] = Field(
        "memory",
        description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
    )

    # Size threshold for trace content, in bytes.
    max_trace_content: int = Field(
        10240,
        description="Maximum trace content size in bytes before truncation (db backend only).",
    )

    # Toggle for token accounting.
    track_token_usage: bool = Field(
        True,
        description="Whether RunJournal should accumulate token counts to RunRow.",
    )
@@ -0,0 +1,4 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
__all__ = ["MemoryRunEventStore", "RunEventStore"]
@@ -0,0 +1,4 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
__all__ = ["MemoryRunEventStore", "RunEventStore"]
@@ -0,0 +1,99 @@
"""Abstract interface for run event storage.
RunEventStore is the unified storage interface for run event streams.
Messages (frontend display) and execution traces (debugging/audit) go
through the same interface, distinguished by the ``category`` field.
Implementations:
- MemoryRunEventStore: in-memory dict (development, tests)
- Future: DB-backed store (SQLAlchemy ORM), JSONL file store
"""
from __future__ import annotations
import abc
class RunEventStore(abc.ABC):
    """Storage contract for the unified run event stream.

    Every backend is expected to uphold the following:

    * an event written with put() is visible to later queries;
    * within one thread, seq values are strictly increasing;
    * list_messages() yields only category="message" events;
    * list_events() yields every event of the requested run;
    * returned dicts follow the RunEvent field structure.
    """

    @abc.abstractmethod
    async def put(
        self,
        *,
        thread_id: str,
        run_id: str,
        event_type: str,
        category: str,
        content: str = "",
        metadata: dict | None = None,
        created_at: str | None = None,
    ) -> dict:
        """Store one event, assign its seq, and return the full record."""

    @abc.abstractmethod
    async def put_batch(self, events: list[dict]) -> list[dict]:
        """Store several events at once (used by the RunJournal flush buffer).

        Every dict in ``events`` carries the same keys as put()'s keyword
        arguments. The complete records, each with its seq assigned, are
        returned.
        """

    @abc.abstractmethod
    async def list_messages(self, thread_id: str, *, limit: int = 50, before_seq: int | None = None, after_seq: int | None = None) -> list[dict]:
        """List a thread's displayable messages (category=message) by ascending seq.

        Cursor pagination works in both directions:

        * ``before_seq`` -- the last ``limit`` records whose seq < before_seq
        * ``after_seq`` -- the first ``limit`` records whose seq > after_seq
        * neither given -- the latest ``limit`` records

        Results are ascending by seq in every case.
        """

    @abc.abstractmethod
    async def list_events(self, thread_id: str, run_id: str, *, event_types: list[str] | None = None, limit: int = 500) -> list[dict]:
        """List the complete event stream of one run by ascending seq.

        ``event_types`` optionally narrows the result to those event types.
        """

    @abc.abstractmethod
    async def list_messages_by_run(self, thread_id: str, run_id: str) -> list[dict]:
        """List one run's displayable messages (category=message) by ascending seq."""

    @abc.abstractmethod
    async def count_messages(self, thread_id: str) -> int:
        """Return how many displayable messages (category=message) a thread holds."""

    @abc.abstractmethod
    async def delete_by_thread(self, thread_id: str) -> int:
        """Remove every event of a thread and return how many were removed."""

    @abc.abstractmethod
    async def delete_by_run(self, thread_id: str, run_id: str) -> int:
        """Remove every event of one run and return how many were removed."""
@@ -0,0 +1,120 @@
"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests.
Safe for concurrent coroutines running on a single event loop: no method
awaits part-way through a mutation, so no locks are needed. Not safe to
share across OS threads or processes.
"""
from __future__ import annotations
from datetime import UTC, datetime
from deerflow.runtime.events.store.base import RunEventStore
class MemoryRunEventStore(RunEventStore):
    """RunEventStore that keeps every event in per-thread Python lists.

    Events are appended in write order, so each per-thread list is already
    sorted by ascending ``seq``. Nothing is persisted; all data lives in the
    instance. Query methods return the stored record dicts themselves (not
    copies), so callers should treat them as read-only.
    """

    def __init__(self) -> None:
        # thread_id -> events in insertion (= ascending seq) order
        self._events: dict[str, list[dict]] = {}
        # thread_id -> last seq handed out for that thread
        self._seq_counters: dict[str, int] = {}

    def _next_seq(self, thread_id: str) -> int:
        """Advance and return the seq counter of ``thread_id`` (first value is 1)."""
        next_val = self._seq_counters.get(thread_id, 0) + 1
        self._seq_counters[thread_id] = next_val
        return next_val

    def _put_one(
        self,
        *,
        thread_id: str,
        run_id: str,
        event_type: str,
        category: str,
        content: str = "",
        metadata: dict | None = None,
        created_at: str | None = None,
    ) -> dict:
        """Build one event record, append it to its thread, and return it.

        ``created_at`` defaults to the current UTC time in ISO-8601 form.
        """
        seq = self._next_seq(thread_id)
        record = {
            "thread_id": thread_id,
            "run_id": run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            # Shallow-copy so that the caller mutating its own dict after the
            # write cannot silently rewrite the stored event.
            "metadata": dict(metadata) if metadata else {},
            "seq": seq,
            "created_at": created_at or datetime.now(UTC).isoformat(),
        }
        self._events.setdefault(thread_id, []).append(record)
        return record

    async def put(
        self,
        *,
        thread_id: str,
        run_id: str,
        event_type: str,
        category: str,
        content: str = "",
        metadata: dict | None = None,
        created_at: str | None = None,
    ) -> dict:
        """Write one event, auto-assign its seq, and return the complete record."""
        return self._put_one(
            thread_id=thread_id,
            run_id=run_id,
            event_type=event_type,
            category=category,
            content=content,
            metadata=metadata,
            created_at=created_at,
        )

    async def put_batch(self, events: list[dict]) -> list[dict]:
        """Write events in order and return their records with seq assigned.

        Each dict is expanded into the keyword arguments of ``put()``. The
        batch is not atomic: if one dict is malformed, the events before it
        have already been stored when the ``TypeError`` propagates.
        """
        return [self._put_one(**ev) for ev in events]

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[dict]:
        """Return up to ``limit`` category="message" events, ascending by seq.

        - ``before_seq``: the last ``limit`` messages with seq < before_seq
        - ``after_seq``: the first ``limit`` messages with seq > after_seq
        - neither: the latest ``limit`` messages

        When both cursors are given, ``before_seq`` wins and ``after_seq`` is
        ignored. A ``limit`` of zero or less yields an empty list.
        """
        # Guard first: ``messages[-0:]`` is the WHOLE list, so limit=0 would
        # otherwise return everything instead of nothing.
        if limit <= 0:
            return []
        messages = [e for e in self._events.get(thread_id, []) if e["category"] == "message"]
        if before_seq is not None:
            return [e for e in messages if e["seq"] < before_seq][-limit:]
        if after_seq is not None:
            return [e for e in messages if e["seq"] > after_seq][:limit]
        return messages[-limit:]

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
    ) -> list[dict]:
        """Return the first ``limit`` events of a run, ascending by seq.

        All categories are included. ``event_types``, when given, keeps only
        events whose ``event_type`` is in it. A negative ``limit`` is treated
        as zero.
        """
        filtered = [e for e in self._events.get(thread_id, []) if e["run_id"] == run_id]
        if event_types is not None:
            filtered = [e for e in filtered if e["event_type"] in event_types]
        # Clamp so a negative limit cannot turn into "all but the last N".
        return filtered[: max(limit, 0)]

    async def list_messages_by_run(self, thread_id: str, run_id: str) -> list[dict]:
        """Return the category="message" events of one run, ascending by seq."""
        return [
            e
            for e in self._events.get(thread_id, [])
            if e["run_id"] == run_id and e["category"] == "message"
        ]

    async def count_messages(self, thread_id: str) -> int:
        """Count the category="message" events stored for a thread."""
        return sum(1 for e in self._events.get(thread_id, []) if e["category"] == "message")

    async def delete_by_thread(self, thread_id: str) -> int:
        """Drop all events of a thread and return how many were dropped.

        The thread's seq counter is dropped too, so a later write to the same
        ``thread_id`` starts again at seq 1.
        """
        events = self._events.pop(thread_id, [])
        self._seq_counters.pop(thread_id, None)
        return len(events)

    async def delete_by_run(self, thread_id: str, run_id: str) -> int:
        """Drop all events of one run and return how many were dropped.

        The thread's seq counter is left untouched, so seq values of the
        deleted events are not reused.
        """
        all_events = self._events.get(thread_id, [])
        if not all_events:
            return 0
        remaining = [e for e in all_events if e["run_id"] != run_id]
        self._events[thread_id] = remaining
        return len(all_events) - len(remaining)
+272
View File
@@ -0,0 +1,272 @@
"""Tests for RunEventStore ABC + MemoryRunEventStore.
Covers:
- Basic write and query (put, seq assignment, cross-thread independence)
- list_messages (category filtering, pagination, cross-run ordering)
- list_events (run filtering, event_types filtering)
- list_messages_by_run
- count_messages
- put_batch
- delete_by_thread, delete_by_run
- Edge cases (empty thread/run)
"""
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
@pytest.fixture
def store():
    """Provide a brand-new, empty MemoryRunEventStore to each test."""
    event_store = MemoryRunEventStore()
    return event_store
# -- Basic write and query --
class TestPutAndSeq:
    """put(): shape of the returned record and per-thread seq assignment."""

    @pytest.mark.anyio
    async def test_put_returns_dict_with_seq(self, store):
        """The first event of a thread comes back fully populated with seq 1."""
        stored = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            content="hello",
        )
        expected_fields = {
            "seq": 1,
            "thread_id": "t1",
            "run_id": "r1",
            "event_type": "human_message",
            "category": "message",
            "content": "hello",
            "metadata": {},
        }
        for field, value in expected_fields.items():
            assert field in stored
            assert stored[field] == value
        assert "created_at" in stored

    @pytest.mark.anyio
    async def test_seq_strictly_increasing_same_thread(self, store):
        """Consecutive writes to one thread get seq 1, 2, 3 regardless of category."""
        writes = [
            ("human_message", "message"),
            ("ai_message", "message"),
            ("llm_end", "trace"),
        ]
        assigned = []
        for event_type, category in writes:
            stored = await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
            assigned.append(stored["seq"])
        assert assigned == [1, 2, 3]

    @pytest.mark.anyio
    async def test_seq_independent_across_threads(self, store):
        """Each thread has its own counter, so both first events get seq 1."""
        first_t1 = await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        first_t2 = await store.put(thread_id="t2", run_id="r2", event_type="human_message", category="message")
        assert (first_t1["seq"], first_t2["seq"]) == (1, 1)

    @pytest.mark.anyio
    async def test_put_respects_provided_created_at(self, store):
        """An explicit created_at is stored verbatim."""
        ts = "2024-06-01T12:00:00+00:00"
        stored = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            created_at=ts,
        )
        assert stored["created_at"] == ts

    @pytest.mark.anyio
    async def test_put_metadata_preserved(self, store):
        """Metadata passed to put() is returned with equal content."""
        meta = {"model": "gpt-4", "tokens": 100}
        stored = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="llm_end",
            category="trace",
            metadata=meta,
        )
        assert stored["metadata"] == meta
# -- list_messages --
class TestListMessages:
    """list_messages(): category filtering, ordering, and cursor pagination."""

    @pytest.mark.anyio
    async def test_only_returns_message_category(self, store):
        """Trace and lifecycle events are filtered out."""
        writes = [
            ("human_message", "message"),
            ("llm_end", "trace"),
            ("run_start", "lifecycle"),
        ]
        for event_type, category in writes:
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
        listed = await store.list_messages("t1")
        assert len(listed) == 1
        assert listed[0]["category"] == "message"

    @pytest.mark.anyio
    async def test_ascending_seq_order(self, store):
        """Messages come back sorted by ascending seq."""
        writes = [
            ("human_message", "first"),
            ("ai_message", "second"),
            ("human_message", "third"),
        ]
        for event_type, text in writes:
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category="message", content=text)
        listed = await store.list_messages("t1")
        seqs = [m["seq"] for m in listed]
        assert seqs == sorted(seqs)

    @pytest.mark.anyio
    async def test_before_seq_pagination(self, store):
        """before_seq returns the last `limit` messages below the cursor."""
        # Ten messages, seq 1..10
        for idx in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(idx))
        page = await store.list_messages("t1", before_seq=6, limit=3)
        assert len(page) == 3
        assert [m["seq"] for m in page] == [3, 4, 5]

    @pytest.mark.anyio
    async def test_after_seq_pagination(self, store):
        """after_seq returns the first `limit` messages above the cursor."""
        for idx in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(idx))
        page = await store.list_messages("t1", after_seq=7, limit=3)
        assert len(page) == 3
        assert [m["seq"] for m in page] == [8, 9, 10]

    @pytest.mark.anyio
    async def test_limit_restricts_count(self, store):
        """No more than `limit` messages are returned."""
        for _ in range(20):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        page = await store.list_messages("t1", limit=5)
        assert len(page) == 5

    @pytest.mark.anyio
    async def test_cross_run_unified_ordering(self, store):
        """Messages of different runs share one seq ordering within a thread."""
        for run_id in ("r1", "r2"):
            for event_type in ("human_message", "ai_message"):
                await store.put(thread_id="t1", run_id=run_id, event_type=event_type, category="message")
        listed = await store.list_messages("t1")
        assert [m["seq"] for m in listed] == [1, 2, 3, 4]
        assert listed[0]["run_id"] == "r1"
        assert listed[2]["run_id"] == "r2"

    @pytest.mark.anyio
    async def test_default_returns_latest(self, store):
        """Without a cursor, the latest `limit` messages are returned."""
        for _ in range(10):
            await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        page = await store.list_messages("t1", limit=3)
        assert [m["seq"] for m in page] == [8, 9, 10]
# -- list_events --
class TestListEvents:
    """list_events(): run scoping and event_types filtering."""

    @pytest.mark.anyio
    async def test_returns_all_categories_for_run(self, store):
        """Every category of a run is included."""
        writes = [
            ("human_message", "message"),
            ("llm_end", "trace"),
            ("run_start", "lifecycle"),
        ]
        for event_type, category in writes:
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
        listed = await store.list_events("t1", "r1")
        assert len(listed) == 3

    @pytest.mark.anyio
    async def test_event_types_filter(self, store):
        """Only events whose type is in event_types are returned."""
        for event_type in ("llm_start", "llm_end", "tool_start"):
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category="trace")
        listed = await store.list_events("t1", "r1", event_types=["llm_end"])
        assert len(listed) == 1
        assert listed[0]["event_type"] == "llm_end"

    @pytest.mark.anyio
    async def test_only_returns_specified_run(self, store):
        """Events of other runs in the same thread are excluded."""
        for run_id in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run_id, event_type="llm_end", category="trace")
        listed = await store.list_events("t1", "r1")
        assert len(listed) == 1
        assert listed[0]["run_id"] == "r1"
# -- list_messages_by_run --
class TestListMessagesByRun:
    """list_messages_by_run(): filters on both run and category."""

    @pytest.mark.anyio
    async def test_only_messages_for_specified_run(self, store):
        """Traces of the run and messages of other runs are both excluded."""
        writes = [
            ("r1", "human_message", "message"),
            ("r1", "llm_end", "trace"),
            ("r2", "human_message", "message"),
        ]
        for run_id, event_type, category in writes:
            await store.put(thread_id="t1", run_id=run_id, event_type=event_type, category=category)
        listed = await store.list_messages_by_run("t1", "r1")
        assert len(listed) == 1
        only = listed[0]
        assert only["run_id"] == "r1"
        assert only["category"] == "message"
# -- count_messages --
class TestCountMessages:
    """count_messages(): counts category="message" events only."""

    @pytest.mark.anyio
    async def test_counts_only_message_category(self, store):
        """Two messages plus one trace count as two."""
        writes = [
            ("human_message", "message"),
            ("ai_message", "message"),
            ("llm_end", "trace"),
        ]
        for event_type, category in writes:
            await store.put(thread_id="t1", run_id="r1", event_type=event_type, category=category)
        total = await store.count_messages("t1")
        assert total == 2
# -- put_batch --
class TestPutBatch:
    """put_batch(): every event of the batch is stored and given a seq."""

    @pytest.mark.anyio
    async def test_batch_assigns_seq(self, store):
        """One record per input event comes back, each carrying a seq."""
        base = {"thread_id": "t1", "run_id": "r1"}
        batch = [
            {**base, "event_type": "human_message", "category": "message", "content": "a"},
            {**base, "event_type": "ai_message", "category": "message", "content": "b"},
            {**base, "event_type": "llm_end", "category": "trace"},
        ]
        stored = await store.put_batch(batch)
        assert len(stored) == 3
        for record in stored:
            assert "seq" in record

    @pytest.mark.anyio
    async def test_batch_seq_strictly_increasing(self, store):
        """Seq values follow the order of the batch."""
        base = {"thread_id": "t1", "run_id": "r1"}
        batch = [
            {**base, "event_type": "human_message", "category": "message"},
            {**base, "event_type": "ai_message", "category": "message"},
        ]
        stored = await store.put_batch(batch)
        first, second = stored[0], stored[1]
        assert first["seq"] == 1
        assert second["seq"] == 2
# -- delete --
class TestDelete:
    """delete_by_thread() / delete_by_run(): removal counts and leftovers."""

    @pytest.mark.anyio
    async def test_delete_by_thread(self, store):
        """Deleting a thread removes all of its events across runs."""
        writes = [
            ("r1", "human_message", "message"),
            ("r1", "ai_message", "message"),
            ("r2", "llm_end", "trace"),
        ]
        for run_id, event_type, category in writes:
            await store.put(thread_id="t1", run_id=run_id, event_type=event_type, category=category)
        removed = await store.delete_by_thread("t1")
        assert removed == 3
        assert await store.list_messages("t1") == []
        assert await store.count_messages("t1") == 0

    @pytest.mark.anyio
    async def test_delete_by_run(self, store):
        """Deleting one run leaves the other runs of the thread intact."""
        writes = [
            ("r1", "human_message", "message"),
            ("r2", "human_message", "message"),
            ("r2", "llm_end", "trace"),
        ]
        for run_id, event_type, category in writes:
            await store.put(thread_id="t1", run_id=run_id, event_type=event_type, category=category)
        removed = await store.delete_by_run("t1", "r2")
        assert removed == 2
        # Only the r1 message should survive
        survivors = await store.list_messages("t1")
        assert len(survivors) == 1
        assert survivors[0]["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_delete_nonexistent_thread_returns_zero(self, store):
        """Deleting an unknown thread reports zero deletions."""
        removed = await store.delete_by_thread("nope")
        assert removed == 0

    @pytest.mark.anyio
    async def test_delete_nonexistent_run_returns_zero(self, store):
        """Deleting an unknown run of a known thread reports zero deletions."""
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        removed = await store.delete_by_run("t1", "nope")
        assert removed == 0

    @pytest.mark.anyio
    async def test_delete_nonexistent_thread_for_run_returns_zero(self, store):
        """Deleting a run of an unknown thread reports zero deletions."""
        removed = await store.delete_by_run("nope", "r1")
        assert removed == 0
# -- Edge cases --
class TestEdgeCases:
    """Queries against threads and runs that were never written to."""

    @pytest.mark.anyio
    async def test_empty_thread_list_messages(self, store):
        """An unknown thread has no messages."""
        listed = await store.list_messages("empty")
        assert listed == []

    @pytest.mark.anyio
    async def test_empty_run_list_events(self, store):
        """An unknown thread/run pair has no events."""
        listed = await store.list_events("empty", "r1")
        assert listed == []

    @pytest.mark.anyio
    async def test_empty_thread_count_messages(self, store):
        """An unknown thread counts zero messages."""
        total = await store.count_messages("empty")
        assert total == 0
+14
View File
@@ -613,6 +613,20 @@ checkpointer:
# backend: postgres # backend: postgres
# postgres_url: $DATABASE_URL # postgres_url: $DATABASE_URL
# ============================================================================
# Run Events Configuration
# ============================================================================
# Storage backend for run events (messages + execution traces).
#
# backend: memory -- No persistence, data lost on restart (default)
# backend: db -- SQL database via ORM, full query capability (production)
# backend: jsonl -- Append-only JSONL files (lightweight single-node persistence)
#
# run_events:
# backend: memory
# max_trace_content: 10240 # Truncation threshold for trace content (db backend, bytes)
# track_token_usage: true # Accumulate token counts to RunRow
# ============================================================================ # ============================================================================
# IM Channels Configuration # IM Channels Configuration
# ============================================================================ # ============================================================================