diff --git a/backend/packages/harness/deerflow/runtime/events/store/memory.py b/backend/packages/harness/deerflow/runtime/events/store/memory.py index cf70e1cdf..a2bb54819 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/events/store/memory.py @@ -6,6 +6,7 @@ since all mutations happen within the same event loop). from __future__ import annotations +import bisect from datetime import UTC, datetime from deerflow.runtime.events.store.base import RunEventStore @@ -13,7 +14,11 @@ from deerflow.runtime.events.store.base import RunEventStore class MemoryRunEventStore(RunEventStore): def __init__(self) -> None: - self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list + self._events: dict[str, list[dict]] = {} # thread_id -> seq-sorted event list + # Messages-only projection of ``_events`` (same dict objects, no copies), + # kept in seq order so message pagination is O(log m + page) via bisect + # instead of re-scanning every event on each request. + self._messages: dict[str, list[dict]] = {} # thread_id -> seq-sorted message list self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq def _next_seq(self, thread_id: str) -> int: @@ -45,6 +50,8 @@ class MemoryRunEventStore(RunEventStore): "created_at": created_at or datetime.now(UTC).isoformat(), } self._events.setdefault(thread_id, []).append(record) + if category == "message": + self._messages.setdefault(thread_id, []).append(record) return record async def put( @@ -76,18 +83,20 @@ class MemoryRunEventStore(RunEventStore): return results async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None): - all_events = self._events.get(thread_id, []) - messages = [e for e in all_events if e["category"] == "message"] + # ``messages`` is messages-only and seq-sorted, so the seq window is a + # contiguous slice located with bisect (O(log m)) rather than a full scan. + messages = self._messages.get(thread_id, []) if before_seq is not None: - messages = [e for e in messages if e["seq"] < before_seq] - # Take the last `limit` records - return messages[-limit:] + # Records with seq < before_seq, then the last `limit` of them. + hi = bisect.bisect_left(messages, before_seq, key=lambda e: e["seq"]) + return messages[max(0, hi - limit) : hi] elif after_seq is not None: - messages = [e for e in messages if e["seq"] > after_seq] - return messages[:limit] + # Records with seq > after_seq, then the first `limit` of them. + lo = bisect.bisect_right(messages, after_seq, key=lambda e: e["seq"]) + return messages[lo : lo + limit] else: - # Return the latest `limit` records, ascending + # Return the latest `limit` records, ascending. return messages[-limit:] async def list_events(self, thread_id, run_id, *, event_types=None, limit=500): @@ -110,11 +119,11 @@ class MemoryRunEventStore(RunEventStore): return filtered[-limit:] if len(filtered) > limit else filtered async def count_messages(self, thread_id): - all_events = self._events.get(thread_id, []) - return sum(1 for e in all_events if e["category"] == "message") + return len(self._messages.get(thread_id, [])) async def delete_by_thread(self, thread_id): events = self._events.pop(thread_id, []) + self._messages.pop(thread_id, None) self._seq_counters.pop(thread_id, None) return len(events) @@ -125,4 +134,6 @@ class MemoryRunEventStore(RunEventStore): remaining = [e for e in all_events if e["run_id"] != run_id] removed = len(all_events) - len(remaining) self._events[thread_id] = remaining + # Keep the message projection in lockstep (same surviving dict objects). + self._messages[thread_id] = [e for e in remaining if e["category"] == "message"] return removed diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index 17b796af7..7c409c711 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -122,6 +122,26 @@ class TestListMessages: messages = await store.list_messages("t1", limit=3) assert [m["seq"] for m in messages] == [8, 9, 10] + @pytest.mark.anyio + async def test_pagination_with_interleaved_trace_events(self, store): + # Messages and non-message events interleave, so message seqs are + # non-contiguous (1, 3, 5, 7, 9). Seq-window pagination must still be + # correct over the messages-only projection, including when the cursor + # lands in a gap or exactly on a message seq (exclusive bound). + for i in range(10): + category = "message" if i % 2 == 0 else "trace" + await store.put(thread_id="t1", run_id="r1", event_type="e", category=category, content=str(i)) + + assert [m["seq"] for m in await store.list_messages("t1")] == [1, 3, 5, 7, 9] + # before_seq in a gap: seq < 6 -> [1, 3, 5], last 2 + assert [m["seq"] for m in await store.list_messages("t1", before_seq=6, limit=2)] == [3, 5] + # before_seq on a message seq is exclusive: seq < 5 -> [1, 3] + assert [m["seq"] for m in await store.list_messages("t1", before_seq=5, limit=5)] == [1, 3] + # after_seq in a gap: seq > 4 -> [5, 7, 9], first 2 + assert [m["seq"] for m in await store.list_messages("t1", after_seq=4, limit=2)] == [5, 7] + # after_seq on a message seq is exclusive: seq > 5 -> [7, 9] + assert [m["seq"] for m in await store.list_messages("t1", after_seq=5, limit=5)] == [7, 9] + # -- list_events --