fix(runtime): make run creation persistence atomic (#3152)

* fix runtime run creation persistence atomicity

* fix run creation cancellation rollback

* fix run manager test cleanup await

* clarify run creation rollback on cancellation

* document new run persistence rollback boundary

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
rayhpeng
2026-05-23 22:43:34 +08:00
committed by GitHub
parent d0fa37e71d
commit 0fb05825a2
2 changed files with 172 additions and 11 deletions
@@ -181,6 +181,23 @@ class RunManager:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True) logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
return False return False
async def _persist_new_run_to_store(self, record: RunRecord) -> None:
"""Persist a newly created run record to the backing store.
Initial run creation is part of the run visibility boundary: callers
should not observe a run in memory unless its backing store row exists.
Unlike follow-up status/model updates, failures are propagated so the
caller can treat creation as failed. Rollback is the caller's
responsibility after inserting the record into ``_runs``.
"""
if self._store is None:
return
await self._call_store_with_retry(
"put",
record.run_id,
lambda: self._store.put(record.run_id, **self._store_put_payload(record)),
)
async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool: async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool:
"""Best-effort persist run record to backing store.""" """Best-effort persist run record to backing store."""
return await self._persist_snapshot_to_store( return await self._persist_snapshot_to_store(
@@ -321,7 +338,17 @@ class RunManager:
) )
async with self._lock: async with self._lock:
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record) persisted = False
try:
await self._persist_new_run_to_store(record)
persisted = True
except Exception:
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
raise
finally:
# Also covers cancellation, which bypasses ``except Exception``.
if not persisted:
self._runs.pop(run_id, None)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -503,16 +530,8 @@ class RunManager:
raise ConflictError(f"Thread {thread_id} already has an active run") raise ConflictError(f"Thread {thread_id} already has an active run")
if multitask_strategy in ("interrupt", "rollback") and inflight: if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_records.append(r)
logger.info( logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)", "Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)",
len(inflight), len(inflight),
thread_id, thread_id,
multitask_strategy, multitask_strategy,
@@ -532,10 +551,30 @@ class RunManager:
model_name=model_name, model_name=model_name,
) )
self._runs[run_id] = record self._runs[run_id] = record
persisted = False
try:
await self._persist_new_run_to_store(record)
persisted = True
except Exception:
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
raise
finally:
# Also covers cancellation, which bypasses ``except Exception``.
if not persisted:
self._runs.pop(run_id, None)
if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_records.append(r)
for interrupted_record in interrupted_records: for interrupted_record in interrupted_records:
await self._persist_status(interrupted_record, RunStatus.interrupted) await self._persist_status(interrupted_record, RunStatus.interrupted)
await self._persist_to_store(record)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
+122
View File
@@ -1,5 +1,6 @@
"""Tests for RunManager.""" """Tests for RunManager."""
import asyncio
import logging import logging
import re import re
import sqlite3 import sqlite3
@@ -471,6 +472,81 @@ async def test_create_record_is_not_store_only(manager: RunManager):
assert record.store_only is False assert record.store_only is False
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_failure():
"""create() must fail and hide the run when the initial store write fails."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
store.put = AsyncMock(side_effect=RuntimeError("db down"))
manager = RunManager(store=store)
with pytest.raises(RuntimeError, match="db down"):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
"""create() must also roll back when cancelled during the initial store write."""
store = MemoryRunStore()
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
manager = RunManager(store=store)
with pytest.raises(asyncio.CancelledError):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_does_not_expose_run_until_store_persist_completes():
"""Concurrent readers must wait until the new run has been persisted."""
store = MemoryRunStore()
manager = RunManager(store=store)
original_put = store.put
put_started = asyncio.Event()
allow_put = asyncio.Event()
async def blocking_put(run_id, **kwargs):
put_started.set()
await allow_put.wait()
return await original_put(run_id, **kwargs)
store.put = blocking_put
create_task = asyncio.create_task(manager.create("thread-1"))
list_task = None
try:
await put_started.wait()
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
await asyncio.sleep(0)
assert not list_task.done()
allow_put.set()
record = await create_task
runs = await list_task
assert [run.run_id for run in runs] == [record.run_id]
finally:
allow_put.set()
cleanup_tasks = []
for task in (list_task, create_task):
if task is None:
continue
if not task.done():
task.cancel()
cleanup_tasks.append(task)
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_prefers_in_memory_record_over_store(): async def test_get_prefers_in_memory_record_over_store():
"""In-memory records retain task/control state when store has same run.""" """In-memory records retain task/control state when store has same run."""
@@ -558,6 +634,52 @@ async def test_create_or_reject_interrupt_persists_interrupted_status_to_store()
assert stored_old["status"] == "interrupted" assert stored_old["status"] == "interrupted"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
"""A failed new-run persist must not cancel the existing inflight run."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
store.put = AsyncMock(side_effect=RuntimeError("db down"))
with pytest.raises(RuntimeError, match="db down"):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
"""Cancellation during new-run persist must not cancel the existing run."""
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
with pytest.raises(asyncio.CancelledError):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_or_reject_rollback_persists_interrupted_status_to_store(): async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
"""rollback strategy should persist interrupted status for old runs.""" """rollback strategy should persist interrupted status for old runs."""