mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
fix(runtime): make run creation persistence atomic (#3152)
* fix runtime run creation persistence atomicity * fix run creation cancellation rollback * fix run manager test cleanup await * clarify run creation rollback on cancellation * document new run persistence rollback boundary --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Tests for RunManager."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import sqlite3
|
||||
@@ -471,6 +472,81 @@ async def test_create_record_is_not_store_only(manager: RunManager):
|
||||
assert record.store_only is False
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_rolls_back_in_memory_record_on_store_failure():
|
||||
"""create() must fail and hide the run when the initial store write fails."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
store = MemoryRunStore()
|
||||
store.put = AsyncMock(side_effect=RuntimeError("db down"))
|
||||
manager = RunManager(store=store)
|
||||
|
||||
with pytest.raises(RuntimeError, match="db down"):
|
||||
await manager.create("thread-1")
|
||||
|
||||
assert manager._runs == {}
|
||||
assert await manager.list_by_thread("thread-1") == []
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
|
||||
"""create() must also roll back when cancelled during the initial store write."""
|
||||
store = MemoryRunStore()
|
||||
|
||||
async def cancelled_put(run_id, **kwargs):
|
||||
raise asyncio.CancelledError
|
||||
|
||||
store.put = cancelled_put
|
||||
manager = RunManager(store=store)
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await manager.create("thread-1")
|
||||
|
||||
assert manager._runs == {}
|
||||
assert await manager.list_by_thread("thread-1") == []
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_does_not_expose_run_until_store_persist_completes():
|
||||
"""Concurrent readers must wait until the new run has been persisted."""
|
||||
store = MemoryRunStore()
|
||||
manager = RunManager(store=store)
|
||||
original_put = store.put
|
||||
put_started = asyncio.Event()
|
||||
allow_put = asyncio.Event()
|
||||
|
||||
async def blocking_put(run_id, **kwargs):
|
||||
put_started.set()
|
||||
await allow_put.wait()
|
||||
return await original_put(run_id, **kwargs)
|
||||
|
||||
store.put = blocking_put
|
||||
create_task = asyncio.create_task(manager.create("thread-1"))
|
||||
list_task = None
|
||||
|
||||
try:
|
||||
await put_started.wait()
|
||||
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
|
||||
await asyncio.sleep(0)
|
||||
assert not list_task.done()
|
||||
|
||||
allow_put.set()
|
||||
record = await create_task
|
||||
runs = await list_task
|
||||
|
||||
assert [run.run_id for run in runs] == [record.run_id]
|
||||
finally:
|
||||
allow_put.set()
|
||||
cleanup_tasks = []
|
||||
for task in (list_task, create_task):
|
||||
if task is None:
|
||||
continue
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
cleanup_tasks.append(task)
|
||||
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_prefers_in_memory_record_over_store():
|
||||
"""In-memory records retain task/control state when store has same run."""
|
||||
@@ -558,6 +634,52 @@ async def test_create_or_reject_interrupt_persists_interrupted_status_to_store()
|
||||
assert stored_old["status"] == "interrupted"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
|
||||
"""A failed new-run persist must not cancel the existing inflight run."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
store = MemoryRunStore()
|
||||
manager = RunManager(store=store)
|
||||
old = await manager.create("thread-1")
|
||||
await manager.set_status(old.run_id, RunStatus.running)
|
||||
store.put = AsyncMock(side_effect=RuntimeError("db down"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="db down"):
|
||||
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||
|
||||
stored_old = await store.get(old.run_id)
|
||||
assert list(manager._runs) == [old.run_id]
|
||||
assert old.status == RunStatus.running
|
||||
assert old.abort_event.is_set() is False
|
||||
assert stored_old is not None
|
||||
assert stored_old["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
|
||||
"""Cancellation during new-run persist must not cancel the existing run."""
|
||||
store = MemoryRunStore()
|
||||
manager = RunManager(store=store)
|
||||
old = await manager.create("thread-1")
|
||||
await manager.set_status(old.run_id, RunStatus.running)
|
||||
|
||||
async def cancelled_put(run_id, **kwargs):
|
||||
raise asyncio.CancelledError
|
||||
|
||||
store.put = cancelled_put
|
||||
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||
|
||||
stored_old = await store.get(old.run_id)
|
||||
assert list(manager._runs) == [old.run_id]
|
||||
assert old.status == RunStatus.running
|
||||
assert old.abort_event.is_set() is False
|
||||
assert stored_old is not None
|
||||
assert stored_old["status"] == "running"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
|
||||
"""rollback strategy should persist interrupted status for old runs."""
|
||||
|
||||
Reference in New Issue
Block a user