mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-13 10:55:59 +00:00
perf(runtime): index messages in MemoryRunEventStore to avoid O(n) scans (#3531)
list_messages re-scanned every event in the thread on each call (category filter + seq filter) — O(total events) per paginated request on the default run-events backend. Maintain a messages-only, seq-sorted projection of _events (shared dict refs, no copies) and locate the seq window with bisect: list_messages drops to O(log m + page) and count_messages to O(1). The index is kept in lockstep at every mutation site (put / put_batch via _put_one, delete_by_run, delete_by_thread). Externally observable behavior is unchanged — the full RunEventStore contract suite passes across memory/db/jsonl. Add a test covering pagination over non-contiguous message seqs (messages interleaved with trace events), including in-gap and exact-boundary cursors. Co-authored-by: ly-wang19 <ly-wang19@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -6,6 +6,7 @@ since all mutations happen within the same event loop).
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import bisect
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
@@ -13,7 +14,11 @@ from deerflow.runtime.events.store.base import RunEventStore
|
|||||||
|
|
||||||
class MemoryRunEventStore(RunEventStore):
|
class MemoryRunEventStore(RunEventStore):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
|
self._events: dict[str, list[dict]] = {} # thread_id -> seq-sorted event list
|
||||||
|
# Messages-only projection of ``_events`` (same dict objects, no copies),
|
||||||
|
# kept in seq order so message pagination is O(log m + page) via bisect
|
||||||
|
# instead of re-scanning every event on each request.
|
||||||
|
self._messages: dict[str, list[dict]] = {} # thread_id -> seq-sorted message list
|
||||||
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
||||||
|
|
||||||
def _next_seq(self, thread_id: str) -> int:
|
def _next_seq(self, thread_id: str) -> int:
|
||||||
@@ -45,6 +50,8 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||||
}
|
}
|
||||||
self._events.setdefault(thread_id, []).append(record)
|
self._events.setdefault(thread_id, []).append(record)
|
||||||
|
if category == "message":
|
||||||
|
self._messages.setdefault(thread_id, []).append(record)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
async def put(
|
async def put(
|
||||||
@@ -76,18 +83,20 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||||
all_events = self._events.get(thread_id, [])
|
# ``messages`` is messages-only and seq-sorted, so the seq window is a
|
||||||
messages = [e for e in all_events if e["category"] == "message"]
|
# contiguous slice located with bisect (O(log m)) rather than a full scan.
|
||||||
|
messages = self._messages.get(thread_id, [])
|
||||||
|
|
||||||
if before_seq is not None:
|
if before_seq is not None:
|
||||||
messages = [e for e in messages if e["seq"] < before_seq]
|
# Records with seq < before_seq, then the last `limit` of them.
|
||||||
# Take the last `limit` records
|
hi = bisect.bisect_left(messages, before_seq, key=lambda e: e["seq"])
|
||||||
return messages[-limit:]
|
return messages[max(0, hi - limit) : hi]
|
||||||
elif after_seq is not None:
|
elif after_seq is not None:
|
||||||
messages = [e for e in messages if e["seq"] > after_seq]
|
# Records with seq > after_seq, then the first `limit` of them.
|
||||||
return messages[:limit]
|
lo = bisect.bisect_right(messages, after_seq, key=lambda e: e["seq"])
|
||||||
|
return messages[lo : lo + limit]
|
||||||
else:
|
else:
|
||||||
# Return the latest `limit` records, ascending
|
# Return the latest `limit` records, ascending.
|
||||||
return messages[-limit:]
|
return messages[-limit:]
|
||||||
|
|
||||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||||
@@ -110,11 +119,11 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(self, thread_id):
|
||||||
all_events = self._events.get(thread_id, [])
|
return len(self._messages.get(thread_id, []))
|
||||||
return sum(1 for e in all_events if e["category"] == "message")
|
|
||||||
|
|
||||||
async def delete_by_thread(self, thread_id):
|
async def delete_by_thread(self, thread_id):
|
||||||
events = self._events.pop(thread_id, [])
|
events = self._events.pop(thread_id, [])
|
||||||
|
self._messages.pop(thread_id, None)
|
||||||
self._seq_counters.pop(thread_id, None)
|
self._seq_counters.pop(thread_id, None)
|
||||||
return len(events)
|
return len(events)
|
||||||
|
|
||||||
@@ -125,4 +134,6 @@ class MemoryRunEventStore(RunEventStore):
|
|||||||
remaining = [e for e in all_events if e["run_id"] != run_id]
|
remaining = [e for e in all_events if e["run_id"] != run_id]
|
||||||
removed = len(all_events) - len(remaining)
|
removed = len(all_events) - len(remaining)
|
||||||
self._events[thread_id] = remaining
|
self._events[thread_id] = remaining
|
||||||
|
# Keep the message projection in lockstep (same surviving dict objects).
|
||||||
|
self._messages[thread_id] = [e for e in remaining if e["category"] == "message"]
|
||||||
return removed
|
return removed
|
||||||
|
|||||||
@@ -122,6 +122,26 @@ class TestListMessages:
|
|||||||
messages = await store.list_messages("t1", limit=3)
|
messages = await store.list_messages("t1", limit=3)
|
||||||
assert [m["seq"] for m in messages] == [8, 9, 10]
|
assert [m["seq"] for m in messages] == [8, 9, 10]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_pagination_with_interleaved_trace_events(self, store):
|
||||||
|
# Messages and non-message events interleave, so message seqs are
|
||||||
|
# non-contiguous (1, 3, 5, 7, 9). Seq-window pagination must still be
|
||||||
|
# correct over the messages-only projection, including when the cursor
|
||||||
|
# lands in a gap or exactly on a message seq (exclusive bound).
|
||||||
|
for i in range(10):
|
||||||
|
category = "message" if i % 2 == 0 else "trace"
|
||||||
|
await store.put(thread_id="t1", run_id="r1", event_type="e", category=category, content=str(i))
|
||||||
|
|
||||||
|
assert [m["seq"] for m in await store.list_messages("t1")] == [1, 3, 5, 7, 9]
|
||||||
|
# before_seq in a gap: seq < 6 -> [1, 3, 5], last 2
|
||||||
|
assert [m["seq"] for m in await store.list_messages("t1", before_seq=6, limit=2)] == [3, 5]
|
||||||
|
# before_seq on a message seq is exclusive: seq < 5 -> [1, 3]
|
||||||
|
assert [m["seq"] for m in await store.list_messages("t1", before_seq=5, limit=5)] == [1, 3]
|
||||||
|
# after_seq in a gap: seq > 4 -> [5, 7, 9], first 2
|
||||||
|
assert [m["seq"] for m in await store.list_messages("t1", after_seq=4, limit=2)] == [5, 7]
|
||||||
|
# after_seq on a message seq is exclusive: seq > 5 -> [7, 9]
|
||||||
|
assert [m["seq"] for m in await store.list_messages("t1", after_seq=5, limit=5)] == [7, 9]
|
||||||
|
|
||||||
|
|
||||||
# -- list_events --
|
# -- list_events --
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user