diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index cce8e3bf5..ae8d89289 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -144,6 +144,9 @@ class RunManager: self._runs[run_id] = record try: await self._persist_new_run_to_store(record) + except asyncio.CancelledError: + self._runs.pop(run_id, None) + raise except Exception: self._runs.pop(run_id, None) logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True) @@ -348,6 +351,9 @@ class RunManager: self._runs[run_id] = record try: await self._persist_new_run_to_store(record) + except asyncio.CancelledError: + self._runs.pop(run_id, None) + raise except Exception: self._runs.pop(run_id, None) logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True) diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index 8dbd277b6..4b88c8501 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -2,6 +2,7 @@ import asyncio import re +from contextlib import suppress import pytest @@ -248,6 +249,24 @@ async def test_create_rolls_back_in_memory_record_on_store_failure(): assert await manager.list_by_thread("thread-1") == [] +@pytest.mark.anyio +async def test_create_rolls_back_in_memory_record_on_store_cancellation(): + """create() must also roll back when cancelled during the initial store write.""" + store = MemoryRunStore() + + async def cancelled_put(run_id, **kwargs): + raise asyncio.CancelledError + + store.put = cancelled_put + manager = RunManager(store=store) + + with pytest.raises(asyncio.CancelledError): + await manager.create("thread-1") + + assert manager._runs == {} + assert await manager.list_by_thread("thread-1") == [] + + @pytest.mark.anyio async def test_create_does_not_expose_run_until_store_persist_completes(): """Concurrent readers must wait until the new run has been persisted.""" @@ -264,6 +283,7 @@ async def test_create_does_not_expose_run_until_store_persist_completes(): store.put = blocking_put create_task = asyncio.create_task(manager.create("thread-1")) + list_task = None try: await put_started.wait() @@ -278,8 +298,13 @@ async def test_create_does_not_expose_run_until_store_persist_completes(): assert [run.run_id for run in runs] == [record.run_id] finally: allow_put.set() - if not create_task.done(): - create_task.cancel() + for task in (list_task, create_task): + if task is None: + continue + if not task.done(): + task.cancel() + with suppress(asyncio.CancelledError): + await task @pytest.mark.anyio @@ -391,6 +416,30 @@ async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_wr assert stored_old["status"] == "running" +@pytest.mark.anyio +async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled(): + """Cancellation during new-run persist must not cancel the existing run.""" + store = MemoryRunStore() + manager = RunManager(store=store) + old = await manager.create("thread-1") + await manager.set_status(old.run_id, RunStatus.running) + + async def cancelled_put(run_id, **kwargs): + raise asyncio.CancelledError + + store.put = cancelled_put + + with pytest.raises(asyncio.CancelledError): + await manager.create_or_reject("thread-1", multitask_strategy="interrupt") + + stored_old = await store.get(old.run_id) + assert list(manager._runs) == [old.run_id] + assert old.status == RunStatus.running + assert old.abort_event.is_set() is False + assert stored_old is not None + assert stored_old["status"] == "running" + + @pytest.mark.anyio async def test_create_or_reject_rollback_persists_interrupted_status_to_store(): """rollback strategy should persist interrupted status for old runs."""