perf(runtime): index messages in MemoryRunEventStore to avoid O(n) scans (#3531)

list_messages re-scanned every event in the thread on each call (category
filter + seq filter) — O(total events) per paginated request on the default
run-events backend. Maintain a messages-only, seq-sorted projection of _events
(shared dict refs, no copies) and locate the seq window with bisect:
list_messages drops to O(log m + page) and count_messages to O(1). The index is
kept in lockstep at every mutation site (put / put_batch via _put_one,
delete_by_run, delete_by_thread).

Externally observable behavior is unchanged — the full RunEventStore contract
suite passes across memory/db/jsonl.

Add a test covering pagination over non-contiguous message seqs (messages
interleaved with trace events), including in-gap and exact-boundary cursors.

Co-authored-by: ly-wang19 <ly-wang19@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
ly-wang19
2026-06-12 22:58:30 +08:00
committed by GitHub
parent c002596ab4
commit 579e416459
2 changed files with 42 additions and 11 deletions
@@ -6,6 +6,7 @@ since all mutations happen within the same event loop).
from __future__ import annotations
import bisect
from datetime import UTC, datetime
from deerflow.runtime.events.store.base import RunEventStore
@@ -13,7 +14,11 @@ from deerflow.runtime.events.store.base import RunEventStore
class MemoryRunEventStore(RunEventStore):
def __init__(self) -> None:
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
self._events: dict[str, list[dict]] = {} # thread_id -> seq-sorted event list
# Messages-only projection of ``_events`` (same dict objects, no copies),
# kept in seq order so message pagination is O(log m + page) via bisect
# instead of re-scanning every event on each request.
self._messages: dict[str, list[dict]] = {} # thread_id -> seq-sorted message list
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
def _next_seq(self, thread_id: str) -> int:
@@ -45,6 +50,8 @@ class MemoryRunEventStore(RunEventStore):
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._events.setdefault(thread_id, []).append(record)
if category == "message":
self._messages.setdefault(thread_id, []).append(record)
return record
async def put(
@@ -76,18 +83,20 @@ class MemoryRunEventStore(RunEventStore):
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._events.get(thread_id, [])
messages = [e for e in all_events if e["category"] == "message"]
# ``messages`` is messages-only and seq-sorted, so the seq window is a
# contiguous slice located with bisect (O(log m)) rather than a full scan.
messages = self._messages.get(thread_id, [])
if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq]
# Take the last `limit` records
return messages[-limit:]
# Records with seq < before_seq, then the last `limit` of them.
hi = bisect.bisect_left(messages, before_seq, key=lambda e: e["seq"])
return messages[max(0, hi - limit) : hi]
elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq]
return messages[:limit]
# Records with seq > after_seq, then the first `limit` of them.
lo = bisect.bisect_right(messages, after_seq, key=lambda e: e["seq"])
return messages[lo : lo + limit]
else:
# Return the latest `limit` records, ascending
# Return the latest `limit` records, ascending.
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
@@ -110,11 +119,11 @@ class MemoryRunEventStore(RunEventStore):
return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id):
all_events = self._events.get(thread_id, [])
return sum(1 for e in all_events if e["category"] == "message")
return len(self._messages.get(thread_id, []))
async def delete_by_thread(self, thread_id):
events = self._events.pop(thread_id, [])
self._messages.pop(thread_id, None)
self._seq_counters.pop(thread_id, None)
return len(events)
@@ -125,4 +134,6 @@ class MemoryRunEventStore(RunEventStore):
remaining = [e for e in all_events if e["run_id"] != run_id]
removed = len(all_events) - len(remaining)
self._events[thread_id] = remaining
# Keep the message projection in lockstep (same surviving dict objects).
self._messages[thread_id] = [e for e in remaining if e["category"] == "message"]
return removed