perf(runtime): index messages in MemoryRunEventStore to avoid O(n) scans (#3531)

list_messages re-scanned every event in the thread on each call (category
filter + seq filter) — O(total events) per paginated request on the default
run-events backend. Maintain a messages-only, seq-sorted projection of _events
(shared dict refs, no copies) and locate the seq window with bisect:
list_messages drops to O(log m + page) and count_messages to O(1). The index is
kept in lockstep at every mutation site (put / put_batch via _put_one,
delete_by_run, delete_by_thread).

Externally observable behavior is unchanged — the full RunEventStore contract
suite passes across memory/db/jsonl.

Add a test covering pagination over non-contiguous message seqs (messages
interleaved with trace events), including in-gap and exact-boundary cursors.

Co-authored-by: ly-wang19 <ly-wang19@users.noreply.github.com>
Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
ly-wang19
2026-06-12 22:58:30 +08:00
committed by GitHub
parent c002596ab4
commit 579e416459
2 changed files with 42 additions and 11 deletions
@@ -6,6 +6,7 @@ since all mutations happen within the same event loop).
from __future__ import annotations from __future__ import annotations
import bisect
from datetime import UTC, datetime from datetime import UTC, datetime
from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.events.store.base import RunEventStore
@@ -13,7 +14,11 @@ from deerflow.runtime.events.store.base import RunEventStore
class MemoryRunEventStore(RunEventStore): class MemoryRunEventStore(RunEventStore):
def __init__(self) -> None: def __init__(self) -> None:
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list self._events: dict[str, list[dict]] = {} # thread_id -> seq-sorted event list
# Messages-only projection of ``_events`` (same dict objects, no copies),
# kept in seq order so message pagination is O(log m + page) via bisect
# instead of re-scanning every event on each request.
self._messages: dict[str, list[dict]] = {} # thread_id -> seq-sorted message list
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
def _next_seq(self, thread_id: str) -> int: def _next_seq(self, thread_id: str) -> int:
@@ -45,6 +50,8 @@ class MemoryRunEventStore(RunEventStore):
"created_at": created_at or datetime.now(UTC).isoformat(), "created_at": created_at or datetime.now(UTC).isoformat(),
} }
self._events.setdefault(thread_id, []).append(record) self._events.setdefault(thread_id, []).append(record)
if category == "message":
self._messages.setdefault(thread_id, []).append(record)
return record return record
async def put( async def put(
@@ -76,18 +83,20 @@ class MemoryRunEventStore(RunEventStore):
return results return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None): async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._events.get(thread_id, []) # ``messages`` is messages-only and seq-sorted, so the seq window is a
messages = [e for e in all_events if e["category"] == "message"] # contiguous slice located with bisect (O(log m)) rather than a full scan.
messages = self._messages.get(thread_id, [])
if before_seq is not None: if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq] # Records with seq < before_seq, then the last `limit` of them.
# Take the last `limit` records hi = bisect.bisect_left(messages, before_seq, key=lambda e: e["seq"])
return messages[-limit:] return messages[max(0, hi - limit) : hi]
elif after_seq is not None: elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq] # Records with seq > after_seq, then the first `limit` of them.
return messages[:limit] lo = bisect.bisect_right(messages, after_seq, key=lambda e: e["seq"])
return messages[lo : lo + limit]
else: else:
# Return the latest `limit` records, ascending # Return the latest `limit` records, ascending.
return messages[-limit:] return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500): async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
@@ -110,11 +119,11 @@ class MemoryRunEventStore(RunEventStore):
return filtered[-limit:] if len(filtered) > limit else filtered return filtered[-limit:] if len(filtered) > limit else filtered
async def count_messages(self, thread_id): async def count_messages(self, thread_id):
all_events = self._events.get(thread_id, []) return len(self._messages.get(thread_id, []))
return sum(1 for e in all_events if e["category"] == "message")
async def delete_by_thread(self, thread_id): async def delete_by_thread(self, thread_id):
events = self._events.pop(thread_id, []) events = self._events.pop(thread_id, [])
self._messages.pop(thread_id, None)
self._seq_counters.pop(thread_id, None) self._seq_counters.pop(thread_id, None)
return len(events) return len(events)
@@ -125,4 +134,6 @@ class MemoryRunEventStore(RunEventStore):
remaining = [e for e in all_events if e["run_id"] != run_id] remaining = [e for e in all_events if e["run_id"] != run_id]
removed = len(all_events) - len(remaining) removed = len(all_events) - len(remaining)
self._events[thread_id] = remaining self._events[thread_id] = remaining
# Keep the message projection in lockstep (same surviving dict objects).
self._messages[thread_id] = [e for e in remaining if e["category"] == "message"]
return removed return removed
+20
View File
@@ -122,6 +122,26 @@ class TestListMessages:
messages = await store.list_messages("t1", limit=3) messages = await store.list_messages("t1", limit=3)
assert [m["seq"] for m in messages] == [8, 9, 10] assert [m["seq"] for m in messages] == [8, 9, 10]
@pytest.mark.anyio
async def test_pagination_with_interleaved_trace_events(self, store):
# Messages and non-message events interleave, so message seqs are
# non-contiguous (1, 3, 5, 7, 9). Seq-window pagination must still be
# correct over the messages-only projection, including when the cursor
# lands in a gap or exactly on a message seq (exclusive bound).
for i in range(10):
category = "message" if i % 2 == 0 else "trace"
await store.put(thread_id="t1", run_id="r1", event_type="e", category=category, content=str(i))
assert [m["seq"] for m in await store.list_messages("t1")] == [1, 3, 5, 7, 9]
# before_seq in a gap: seq < 6 -> [1, 3, 5], last 2
assert [m["seq"] for m in await store.list_messages("t1", before_seq=6, limit=2)] == [3, 5]
# before_seq on a message seq is exclusive: seq < 5 -> [1, 3]
assert [m["seq"] for m in await store.list_messages("t1", before_seq=5, limit=5)] == [1, 3]
# after_seq in a gap: seq > 4 -> [5, 7, 9], first 2
assert [m["seq"] for m in await store.list_messages("t1", after_seq=4, limit=2)] == [5, 7]
# after_seq on a message seq is exclusive: seq > 5 -> [7, 9]
assert [m["seq"] for m in await store.list_messages("t1", after_seq=5, limit=5)] == [7, 9]
# -- list_events -- # -- list_events --