mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-13 10:55:59 +00:00
perf(runtime): index runs by thread_id to avoid O(n) scans in RunManager (#3499)
* perf(runtime): index runs by thread_id to avoid O(n) scans in RunManager RunManager.list_by_thread, create_or_reject (inflight check), and has_inflight each filtered every in-memory run by thread_id — an O(total in-memory runs) scan that grows with overall gateway traffic rather than the queried thread's depth. Add a thread_id -> run_ids secondary index (an insertion-ordered dict used as an ordered set) maintained in lockstep with _runs under the existing lock at every add/remove site (create, create_or_reject, both rollbacks, cleanup). The three per-thread queries now run in O(runs-in-thread); insertion order is preserved so list_by_thread keeps stable tie-breaking. Behavior unchanged. Adds 6 regression tests; full RunManager suite 146 passed. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * test(runtime): cover create_or_reject rollback + clarify thread-index guard docstrings Address review on #3499 (fancyboi999): - Reword _thread_records_locked docstring: lockstep under self._lock is the correctness guarantee; self._runs.get is one-directional defense-in-depth (drops stale ids, cannot recover index-missing ids), not reconciliation. - Add test_failed_create_or_reject_unindexes_run covering the create_or_reject rollback/unindex mutation site (the last untested mutation path). - Fix _FailingPutRunStore docstring ("initial put" -> "every put"). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: ly-wang19 <ly-wang19@users.noreply.github.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -119,10 +119,45 @@ class RunManager:
|
|||||||
persistence_retry_policy: PersistenceRetryPolicy | None = None,
|
persistence_retry_policy: PersistenceRetryPolicy | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._runs: dict[str, RunRecord] = {}
|
self._runs: dict[str, RunRecord] = {}
|
||||||
|
# Secondary index: thread_id -> insertion-ordered run_id set (a dict is
|
||||||
|
# used as an ordered set), maintained in lockstep with ``_runs`` so
|
||||||
|
# per-thread queries avoid O(total in-memory runs) full scans while
|
||||||
|
# preserving ``_runs`` iteration order (see ``_thread_records_locked``).
|
||||||
|
self._runs_by_thread: dict[str, dict[str, None]] = {}
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
|
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
|
||||||
|
|
||||||
|
def _index_run_locked(self, record: RunRecord) -> None:
|
||||||
|
"""Register *record* in the thread index. Caller must hold ``self._lock``."""
|
||||||
|
self._runs_by_thread.setdefault(record.thread_id, {})[record.run_id] = None
|
||||||
|
|
||||||
|
def _unindex_run_locked(self, run_id: str, thread_id: str) -> None:
|
||||||
|
"""Drop *run_id* from the thread index. Caller must hold ``self._lock``."""
|
||||||
|
bucket = self._runs_by_thread.get(thread_id)
|
||||||
|
if bucket is not None:
|
||||||
|
bucket.pop(run_id, None)
|
||||||
|
if not bucket:
|
||||||
|
self._runs_by_thread.pop(thread_id, None)
|
||||||
|
|
||||||
|
def _thread_records_locked(self, thread_id: str) -> list[RunRecord]:
|
||||||
|
"""Return live in-memory records for *thread_id*. Caller must hold ``self._lock``.
|
||||||
|
|
||||||
|
Uses the ``_runs_by_thread`` index for O(runs-in-thread) lookup instead of
|
||||||
|
scanning every in-memory run. Correctness rests on the index and ``_runs``
|
||||||
|
being mutated in lockstep under ``self._lock`` (no ``await`` between the two
|
||||||
|
writes), so any holder of the lock sees them agree. The ``self._runs.get``
|
||||||
|
filter is defense-in-depth, not reconciliation: it drops a stale id still in
|
||||||
|
the index but already gone from ``_runs``, yet it cannot recover a run that is
|
||||||
|
in ``_runs`` but missing from the index (such a run would be silently
|
||||||
|
omitted). It guards only that one direction, should a future refactor ever
|
||||||
|
break the lockstep invariant.
|
||||||
|
"""
|
||||||
|
run_ids = self._runs_by_thread.get(thread_id)
|
||||||
|
if not run_ids:
|
||||||
|
return []
|
||||||
|
return [record for run_id in run_ids if (record := self._runs.get(run_id)) is not None]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
||||||
payload = {
|
payload = {
|
||||||
@@ -345,6 +380,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
self._index_run_locked(record)
|
||||||
persisted = False
|
persisted = False
|
||||||
try:
|
try:
|
||||||
await self._persist_new_run_to_store(record)
|
await self._persist_new_run_to_store(record)
|
||||||
@@ -356,6 +392,7 @@ class RunManager:
|
|||||||
# Also covers cancellation, which bypasses ``except Exception``.
|
# Also covers cancellation, which bypasses ``except Exception``.
|
||||||
if not persisted:
|
if not persisted:
|
||||||
self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
|
self._unindex_run_locked(run_id, record.thread_id)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -411,8 +448,7 @@ class RunManager:
|
|||||||
limit: Maximum number of runs to return.
|
limit: Maximum number of runs to return.
|
||||||
"""
|
"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
# Dict insertion order gives deterministic results when timestamps tie.
|
memory_records = self._thread_records_locked(thread_id)
|
||||||
memory_records = [r for r in self._runs.values() if r.thread_id == thread_id]
|
|
||||||
if self._store is None:
|
if self._store is None:
|
||||||
return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
|
return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
|
||||||
records_by_id = {record.run_id: record for record in memory_records}
|
records_by_id = {record.run_id: record for record in memory_records}
|
||||||
@@ -532,7 +568,7 @@ class RunManager:
|
|||||||
if multitask_strategy not in _supported_strategies:
|
if multitask_strategy not in _supported_strategies:
|
||||||
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
||||||
|
|
||||||
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
|
inflight = [r for r in self._thread_records_locked(thread_id) if r.status in (RunStatus.pending, RunStatus.running)]
|
||||||
|
|
||||||
if multitask_strategy == "reject" and inflight:
|
if multitask_strategy == "reject" and inflight:
|
||||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||||
@@ -560,6 +596,7 @@ class RunManager:
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
self._index_run_locked(record)
|
||||||
persisted = False
|
persisted = False
|
||||||
try:
|
try:
|
||||||
await self._persist_new_run_to_store(record)
|
await self._persist_new_run_to_store(record)
|
||||||
@@ -571,6 +608,7 @@ class RunManager:
|
|||||||
# Also covers cancellation, which bypasses ``except Exception``.
|
# Also covers cancellation, which bypasses ``except Exception``.
|
||||||
if not persisted:
|
if not persisted:
|
||||||
self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
|
self._unindex_run_locked(run_id, record.thread_id)
|
||||||
|
|
||||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||||
for r in inflight:
|
for r in inflight:
|
||||||
@@ -644,14 +682,16 @@ class RunManager:
|
|||||||
async def has_inflight(self, thread_id: str) -> bool:
|
async def has_inflight(self, thread_id: str) -> bool:
|
||||||
"""Return ``True`` if *thread_id* has a pending or running run."""
|
"""Return ``True`` if *thread_id* has a pending or running run."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
|
return any(r.status in (RunStatus.pending, RunStatus.running) for r in self._thread_records_locked(thread_id))
|
||||||
|
|
||||||
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
||||||
"""Remove a run record after an optional delay."""
|
"""Remove a run record after an optional delay."""
|
||||||
if delay > 0:
|
if delay > 0:
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs.pop(run_id, None)
|
record = self._runs.pop(run_id, None)
|
||||||
|
if record is not None:
|
||||||
|
self._unindex_run_locked(run_id, record.thread_id)
|
||||||
logger.debug("Run record %s cleaned up", run_id)
|
logger.debug("Run record %s cleaned up", run_id)
|
||||||
|
|
||||||
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import pytest
|
|||||||
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
|
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
|
||||||
|
|
||||||
from deerflow.runtime import DisconnectMode, RunManager, RunStatus
|
from deerflow.runtime import DisconnectMode, RunManager, RunStatus
|
||||||
from deerflow.runtime.runs.manager import PersistenceRetryPolicy
|
from deerflow.runtime.runs.manager import ConflictError, PersistenceRetryPolicy
|
||||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
|
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
|
||||||
@@ -871,3 +871,100 @@ async def test_list_by_thread_falls_back_to_store_with_user_filter():
|
|||||||
|
|
||||||
runs = await mgr.list_by_thread("thread-1", user_id="user-1")
|
runs = await mgr.list_by_thread("thread-1", user_id="user-1")
|
||||||
assert [r.run_id for r in runs] == ["run-1"]
|
assert [r.run_id for r in runs] == ["run-1"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-thread index (thread_id -> run_ids): keeps per-thread queries
|
||||||
|
# O(runs-in-thread) instead of scanning every in-memory run, and stays
|
||||||
|
# consistent with ``_runs`` across create / cleanup / rollback.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _FailingPutRunStore(MemoryRunStore):
|
||||||
|
"""Memory run store whose every ``put`` fails (non-retryably)."""
|
||||||
|
|
||||||
|
async def put(self, run_id, **kwargs):
|
||||||
|
raise ValueError("simulated persist failure")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_thread_index_scopes_runs_per_thread(manager: RunManager):
|
||||||
|
a1 = await manager.create("thread-a")
|
||||||
|
a2 = await manager.create("thread-a")
|
||||||
|
b1 = await manager.create("thread-b")
|
||||||
|
|
||||||
|
# The index mirrors _runs membership, bucketed by thread.
|
||||||
|
assert set(manager._runs_by_thread["thread-a"]) == {a1.run_id, a2.run_id}
|
||||||
|
assert set(manager._runs_by_thread["thread-b"]) == {b1.run_id}
|
||||||
|
|
||||||
|
# Per-thread queries return only that thread's runs (no cross-thread leak).
|
||||||
|
assert {r.run_id for r in await manager.list_by_thread("thread-a")} == {a1.run_id, a2.run_id}
|
||||||
|
assert {r.run_id for r in await manager.list_by_thread("thread-b")} == {b1.run_id}
|
||||||
|
assert await manager.list_by_thread("thread-missing") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_thread_index_preserves_insertion_order(manager: RunManager):
|
||||||
|
# The index is insertion-ordered (dict-as-ordered-set) so list_by_thread
|
||||||
|
# keeps the stable tie-breaking the full-scan implementation guaranteed.
|
||||||
|
first = await manager.create("thread-a")
|
||||||
|
second = await manager.create("thread-a")
|
||||||
|
assert list(manager._runs_by_thread["thread-a"]) == [first.run_id, second.run_id]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_thread_index_cleanup_prunes_run_and_empty_bucket(manager: RunManager):
|
||||||
|
a1 = await manager.create("thread-a")
|
||||||
|
a2 = await manager.create("thread-a")
|
||||||
|
|
||||||
|
await manager.cleanup(a1.run_id, delay=0)
|
||||||
|
assert a1.run_id not in manager._runs
|
||||||
|
assert set(manager._runs_by_thread["thread-a"]) == {a2.run_id}
|
||||||
|
|
||||||
|
await manager.cleanup(a2.run_id, delay=0)
|
||||||
|
# Empty buckets are pruned so the index cannot grow without bound.
|
||||||
|
assert "thread-a" not in manager._runs_by_thread
|
||||||
|
assert await manager.list_by_thread("thread-a") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_has_inflight_reflects_index(manager: RunManager):
|
||||||
|
record = await manager.create("thread-a")
|
||||||
|
assert await manager.has_inflight("thread-a") is True
|
||||||
|
assert await manager.has_inflight("thread-b") is False
|
||||||
|
|
||||||
|
await manager.set_status(record.run_id, RunStatus.success)
|
||||||
|
assert await manager.has_inflight("thread-a") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_inflight_is_thread_scoped(manager: RunManager):
|
||||||
|
await manager.create_or_reject("thread-a", multitask_strategy="reject")
|
||||||
|
# A different thread is unaffected by thread-a's active run.
|
||||||
|
await manager.create_or_reject("thread-b", multitask_strategy="reject")
|
||||||
|
# A second active run on the same thread is rejected.
|
||||||
|
with pytest.raises(ConflictError):
|
||||||
|
await manager.create_or_reject("thread-a", multitask_strategy="reject")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_failed_create_unindexes_run():
|
||||||
|
manager = RunManager(store=_FailingPutRunStore())
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await manager.create("thread-a")
|
||||||
|
# A rolled-back run must leave no trace in either _runs or the index.
|
||||||
|
assert manager._runs == {}
|
||||||
|
assert "thread-a" not in manager._runs_by_thread
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_failed_create_or_reject_unindexes_run():
|
||||||
|
# Symmetric to test_failed_create_unindexes_run: create_or_reject has its own
|
||||||
|
# insert + rollback-unindex site, so a persist failure there must also leave
|
||||||
|
# neither _runs nor the index holding the rolled-back run. This closes the last
|
||||||
|
# mutation path not exercised by an index-consistency test.
|
||||||
|
manager = RunManager(store=_FailingPutRunStore())
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await manager.create_or_reject("thread-a", multitask_strategy="reject")
|
||||||
|
assert manager._runs == {}
|
||||||
|
assert "thread-a" not in manager._runs_by_thread
|
||||||
|
|||||||
Reference in New Issue
Block a user