mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 11a362e5e5 | |||
| 85402405ec | |||
| 43eb643910 | |||
| f3e3a350ce | |||
| 0fae7c9cbb |
@@ -53,24 +53,27 @@ class RunManager:
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
async def _persist_to_store(self, record: RunRecord) -> None:
|
async def _persist_new_run_to_store(self, record: RunRecord) -> None:
|
||||||
"""Best-effort persist run record to backing store."""
|
"""Persist a newly created run record to the backing store.
|
||||||
|
|
||||||
|
Initial run creation is part of the run visibility boundary: callers
|
||||||
|
should not observe a run in memory unless its backing store row exists.
|
||||||
|
Unlike follow-up status/model updates, failures are propagated so the
|
||||||
|
caller can treat creation as failed.
|
||||||
|
"""
|
||||||
if self._store is None:
|
if self._store is None:
|
||||||
return
|
return
|
||||||
try:
|
await self._store.put(
|
||||||
await self._store.put(
|
record.run_id,
|
||||||
record.run_id,
|
thread_id=record.thread_id,
|
||||||
thread_id=record.thread_id,
|
assistant_id=record.assistant_id,
|
||||||
assistant_id=record.assistant_id,
|
status=record.status.value,
|
||||||
status=record.status.value,
|
multitask_strategy=record.multitask_strategy,
|
||||||
multitask_strategy=record.multitask_strategy,
|
metadata=record.metadata or {},
|
||||||
metadata=record.metadata or {},
|
kwargs=record.kwargs or {},
|
||||||
kwargs=record.kwargs or {},
|
created_at=record.created_at,
|
||||||
created_at=record.created_at,
|
model_name=record.model_name,
|
||||||
model_name=record.model_name,
|
)
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
|
||||||
|
|
||||||
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||||
"""Best-effort persist a status transition to the backing store."""
|
"""Best-effort persist a status transition to the backing store."""
|
||||||
@@ -139,7 +142,16 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
await self._persist_to_store(record)
|
persisted = False
|
||||||
|
try:
|
||||||
|
await self._persist_new_run_to_store(record)
|
||||||
|
persisted = True
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if not persisted:
|
||||||
|
self._runs.pop(run_id, None)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -317,16 +329,8 @@ class RunManager:
|
|||||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||||
|
|
||||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||||
for r in inflight:
|
|
||||||
r.abort_action = multitask_strategy
|
|
||||||
r.abort_event.set()
|
|
||||||
if r.task is not None and not r.task.done():
|
|
||||||
r.task.cancel()
|
|
||||||
r.status = RunStatus.interrupted
|
|
||||||
r.updated_at = now
|
|
||||||
interrupted_run_ids.append(r.run_id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
|
"Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)",
|
||||||
len(inflight),
|
len(inflight),
|
||||||
thread_id,
|
thread_id,
|
||||||
multitask_strategy,
|
multitask_strategy,
|
||||||
@@ -346,10 +350,29 @@ class RunManager:
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
persisted = False
|
||||||
|
try:
|
||||||
|
await self._persist_new_run_to_store(record)
|
||||||
|
persisted = True
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if not persisted:
|
||||||
|
self._runs.pop(run_id, None)
|
||||||
|
|
||||||
|
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||||
|
for r in inflight:
|
||||||
|
r.abort_action = multitask_strategy
|
||||||
|
r.abort_event.set()
|
||||||
|
if r.task is not None and not r.task.done():
|
||||||
|
r.task.cancel()
|
||||||
|
r.status = RunStatus.interrupted
|
||||||
|
r.updated_at = now
|
||||||
|
interrupted_run_ids.append(r.run_id)
|
||||||
|
|
||||||
for interrupted_run_id in interrupted_run_ids:
|
for interrupted_run_id in interrupted_run_ids:
|
||||||
await self._persist_status(interrupted_run_id, RunStatus.interrupted)
|
await self._persist_status(interrupted_run_id, RunStatus.interrupted)
|
||||||
await self._persist_to_store(record)
|
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Tests for RunManager."""
|
"""Tests for RunManager."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -231,6 +232,81 @@ async def test_create_record_is_not_store_only(manager: RunManager):
|
|||||||
assert record.store_only is False
|
assert record.store_only is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_rolls_back_in_memory_record_on_store_failure():
|
||||||
|
"""create() must fail and hide the run when the initial store write fails."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
store = MemoryRunStore()
|
||||||
|
store.put = AsyncMock(side_effect=RuntimeError("db down"))
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="db down"):
|
||||||
|
await manager.create("thread-1")
|
||||||
|
|
||||||
|
assert manager._runs == {}
|
||||||
|
assert await manager.list_by_thread("thread-1") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
|
||||||
|
"""create() must also roll back when cancelled during the initial store write."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
|
||||||
|
async def cancelled_put(run_id, **kwargs):
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
store.put = cancelled_put
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await manager.create("thread-1")
|
||||||
|
|
||||||
|
assert manager._runs == {}
|
||||||
|
assert await manager.list_by_thread("thread-1") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_does_not_expose_run_until_store_persist_completes():
|
||||||
|
"""Concurrent readers must wait until the new run has been persisted."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
original_put = store.put
|
||||||
|
put_started = asyncio.Event()
|
||||||
|
allow_put = asyncio.Event()
|
||||||
|
|
||||||
|
async def blocking_put(run_id, **kwargs):
|
||||||
|
put_started.set()
|
||||||
|
await allow_put.wait()
|
||||||
|
return await original_put(run_id, **kwargs)
|
||||||
|
|
||||||
|
store.put = blocking_put
|
||||||
|
create_task = asyncio.create_task(manager.create("thread-1"))
|
||||||
|
list_task = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await put_started.wait()
|
||||||
|
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert not list_task.done()
|
||||||
|
|
||||||
|
allow_put.set()
|
||||||
|
record = await create_task
|
||||||
|
runs = await list_task
|
||||||
|
|
||||||
|
assert [run.run_id for run in runs] == [record.run_id]
|
||||||
|
finally:
|
||||||
|
allow_put.set()
|
||||||
|
cleanup_tasks = []
|
||||||
|
for task in (list_task, create_task):
|
||||||
|
if task is None:
|
||||||
|
continue
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
cleanup_tasks.append(task)
|
||||||
|
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_prefers_in_memory_record_over_store():
|
async def test_get_prefers_in_memory_record_over_store():
|
||||||
"""In-memory records retain task/control state when store has same run."""
|
"""In-memory records retain task/control state when store has same run."""
|
||||||
@@ -318,6 +394,52 @@ async def test_create_or_reject_interrupt_persists_interrupted_status_to_store()
|
|||||||
assert stored_old["status"] == "interrupted"
|
assert stored_old["status"] == "interrupted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
|
||||||
|
"""A failed new-run persist must not cancel the existing inflight run."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
old = await manager.create("thread-1")
|
||||||
|
await manager.set_status(old.run_id, RunStatus.running)
|
||||||
|
store.put = AsyncMock(side_effect=RuntimeError("db down"))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="db down"):
|
||||||
|
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||||
|
|
||||||
|
stored_old = await store.get(old.run_id)
|
||||||
|
assert list(manager._runs) == [old.run_id]
|
||||||
|
assert old.status == RunStatus.running
|
||||||
|
assert old.abort_event.is_set() is False
|
||||||
|
assert stored_old is not None
|
||||||
|
assert stored_old["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
|
||||||
|
"""Cancellation during new-run persist must not cancel the existing run."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
old = await manager.create("thread-1")
|
||||||
|
await manager.set_status(old.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
async def cancelled_put(run_id, **kwargs):
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
store.put = cancelled_put
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||||
|
|
||||||
|
stored_old = await store.get(old.run_id)
|
||||||
|
assert list(manager._runs) == [old.run_id]
|
||||||
|
assert old.status == RunStatus.running
|
||||||
|
assert old.abort_event.is_set() is False
|
||||||
|
assert stored_old is not None
|
||||||
|
assert stored_old["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
|
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
|
||||||
"""rollback strategy should persist interrupted status for old runs."""
|
"""rollback strategy should persist interrupted status for old runs."""
|
||||||
|
|||||||
Reference in New Issue
Block a user