Files
deer-flow/backend/packages/harness/deerflow/runtime/runs/manager.py
T
rayhpeng 0fb05825a2 fix(runtime): make run creation persistence atomic (#3152)
* fix runtime run creation persistence atomicity

* fix run creation cancellation rollback

* fix run manager test cleanup await

* clarify run creation rollback on cancellation

* document new run persistence rollback boundary

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-23 22:43:34 +08:00

655 lines
26 KiB
Python

"""In-memory run registry with optional persistent RunStore backing."""
from __future__ import annotations
import asyncio
import logging
import sqlite3
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from deerflow.utils.time import now_iso as _now_iso
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
_RETRYABLE_SQLITE_MESSAGES = (
"database is locked",
"database table is locked",
"database is busy",
)
_RETRYABLE_SQLITE_ERROR_CODES = {
sqlite3.SQLITE_BUSY,
sqlite3.SQLITE_LOCKED,
}
def _is_retryable_persistence_error(exc: BaseException) -> bool:
"""Return True for transient SQLite persistence failures.
SQLite lock contention normally surfaces through either sqlite3 exceptions
or SQLAlchemy wrappers. The short bounded retry here protects run status
finalization from transient writer pressure without hiding permanent
failures forever.
"""
pending: list[BaseException] = [exc]
seen: set[int] = set()
while pending:
current = pending.pop()
if id(current) in seen:
continue
seen.add(id(current))
message = str(current).lower()
if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES):
return True
if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)):
error_code = getattr(current, "sqlite_errorcode", None)
if error_code in _RETRYABLE_SQLITE_ERROR_CODES:
return True
for chained in (getattr(current, "orig", None), current.__cause__, current.__context__):
if isinstance(chained, BaseException):
pending.append(chained)
return False
@dataclass(frozen=True)
class PersistenceRetryPolicy:
    """Immutable tuning knobs for retrying short run-store writes.

    ``RunManager._call_store_with_retry`` reads these values. It only retries
    errors that ``_is_retryable_persistence_error`` classifies as transient,
    and the sleep between attempts grows geometrically up to ``max_delay``.
    """

    # Total number of tries, counting the first call. When it is reached,
    # the last error is re-raised.
    max_attempts: int = 5

    # Seconds slept after the first failed attempt. No sleep happens when
    # the value is 0.
    initial_delay: float = 0.05

    # Upper bound, in seconds, on the sleep between two attempts.
    max_delay: float = 1.0

    # Multiplier applied to the sleep after every failed attempt.
    backoff_factor: float = 2.0
@dataclass
class RunRecord:
    """Mutable, in-process state for a single run.

    Field order is part of the positional constructor signature: request and
    identity fields come first, then runtime control handles (hidden from
    ``repr``), then usage counters and message summaries.
    """

    # -- identity and request parameters ------------------------------------
    run_id: str
    thread_id: str
    assistant_id: str | None
    status: RunStatus
    on_disconnect: DisconnectMode
    multitask_strategy: str = "reject"
    metadata: dict = field(default_factory=dict)
    kwargs: dict = field(default_factory=dict)
    # Timestamp strings (RunManager fills them from ``now_iso``); "" when unknown.
    created_at: str = ""
    updated_at: str = ""

    # -- runtime control (excluded from repr) -------------------------------
    task: asyncio.Task | None = field(default=None, repr=False)
    abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
    # Reason recorded when the run is cancelled, e.g. "interrupt" or "rollback".
    abort_action: str = "interrupt"
    error: str | None = None
    model_name: str | None = None
    # True for records rebuilt from a store row rather than created locally.
    store_only: bool = False

    # -- usage counters and message summaries -------------------------------
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    llm_call_count: int = 0
    lead_agent_tokens: int = 0
    subagent_tokens: int = 0
    middleware_tokens: int = 0
    message_count: int = 0
    last_ai_message: str | None = None
    first_human_message: str | None = None
class RunManager:
    """In-memory run registry with optional persistent RunStore backing.

    All mutations are protected by an asyncio lock. When a ``store`` is
    provided, serializable metadata is also persisted to the store so
    that run history survives process restarts.

    Lock discipline: ``_runs`` and the records in it are read and mutated
    while holding ``_lock``. Follow-up store writes (status, model name,
    completion, progress) run after the lock is released and are
    best-effort. Only initial run creation persists while holding the lock,
    so a new run is never visible in memory without its store row.
    """

    def __init__(
        self,
        store: RunStore | None = None,
        *,
        persistence_retry_policy: PersistenceRetryPolicy | None = None,
    ) -> None:
        self._runs: dict[str, RunRecord] = {}
        self._lock = asyncio.Lock()
        self._store = store
        self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()

    @staticmethod
    def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
        """Serialize the persistable fields of *record* as ``RunStore.put`` kwargs.

        Args:
            record: Source record.
            error: When not None, overrides ``record.error`` in the payload.
        """
        # NOTE(review): ``on_disconnect`` is not part of this payload, so rows
        # hydrated by ``_record_from_store`` fall back to ``cancel`` unless the
        # store fills the column itself -- confirm whether RunStore.put accepts it.
        return {
            "thread_id": record.thread_id,
            "assistant_id": record.assistant_id,
            "status": record.status.value,
            "multitask_strategy": record.multitask_strategy,
            "metadata": record.metadata or {},
            "kwargs": record.kwargs or {},
            "error": error if error is not None else record.error,
            "created_at": record.created_at,
            "model_name": record.model_name,
        }

    async def _call_store_with_retry(
        self,
        operation_name: str,
        run_id: str,
        operation: Callable[[], Awaitable[Any]],
    ) -> Any:
        """Run a short store operation with bounded retries for SQLite pressure.

        Args:
            operation_name: Label used in the retry log message.
            run_id: Run the operation concerns (used in the retry log message).
            operation: Zero-argument callable returning a fresh awaitable per attempt.

        Returns:
            Whatever the awaited operation returns.

        Raises:
            The operation's exception, immediately when it is not transient
            (per ``_is_retryable_persistence_error``), or after
            ``max_attempts`` tries otherwise.
        """
        policy = self._persistence_retry_policy
        attempt = 1
        delay = policy.initial_delay
        while True:
            try:
                return await operation()
            except Exception as exc:
                retryable = _is_retryable_persistence_error(exc)
                if attempt >= policy.max_attempts or not retryable:
                    raise
                logger.warning(
                    "Transient persistence failure during %s for run %s (attempt %d/%d); retrying",
                    operation_name,
                    run_id,
                    attempt,
                    policy.max_attempts,
                    exc_info=True,
                )
                if delay > 0:
                    await asyncio.sleep(delay)
                # Exponential backoff capped at max_delay; a zero delay restarts from initial_delay.
                delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay)
                attempt += 1

    async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool:
        """Best-effort persist a previously captured run snapshot.

        Returns:
            True when there is no store or the ``put`` succeeded; False when it
            failed (the failure is logged, not raised).
        """
        if self._store is None:
            return True
        try:
            await self._call_store_with_retry(
                "put",
                run_id,
                lambda: self._store.put(run_id, **payload),
            )
            return True
        except Exception:
            logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
            return False

    async def _persist_new_run_to_store(self, record: RunRecord) -> None:
        """Persist a newly created run record to the backing store.

        Initial run creation is part of the run visibility boundary: callers
        should not observe a run in memory unless its backing store row exists.
        Unlike follow-up status/model updates, failures are propagated so the
        caller can treat creation as failed. Rollback is the caller's
        responsibility after inserting the record into ``_runs``.
        """
        if self._store is None:
            return
        await self._call_store_with_retry(
            "put",
            record.run_id,
            lambda: self._store.put(record.run_id, **self._store_put_payload(record)),
        )

    async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool:
        """Best-effort persist run record to backing store."""
        return await self._persist_snapshot_to_store(
            record.run_id,
            self._store_put_payload(record, error=error),
        )

    async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool:
        """Best-effort persist a status transition to the backing store.

        If ``update_status`` reports that no row was updated (returns False),
        the row is recreated from a snapshot of *record* taken before the call.

        Returns:
            True when the status is persisted (or there is no store), False on
            a logged failure.
        """
        if self._store is None:
            return True
        # Captured up front so row recreation reflects the state at call time.
        row_recovery_payload = self._store_put_payload(record, error=error)
        try:
            updated = await self._call_store_with_retry(
                "update_status",
                record.run_id,
                lambda: self._store.update_status(record.run_id, status.value, error=error),
            )
            if updated is False:
                return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload)
            return True
        except Exception:
            logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True)
            return False

    @staticmethod
    def _record_from_store(row: dict[str, Any]) -> RunRecord:
        """Build a read-only runtime record from a serialized store row.

        NULL status/on_disconnect columns (e.g. from rows written before those
        columns were added) default to ``pending`` and ``cancel`` respectively.
        """
        return RunRecord(
            run_id=row["run_id"],
            thread_id=row["thread_id"],
            assistant_id=row.get("assistant_id"),
            status=RunStatus(row.get("status") or RunStatus.pending.value),
            on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value),
            multitask_strategy=row.get("multitask_strategy") or "reject",
            metadata=row.get("metadata") or {},
            kwargs=row.get("kwargs") or {},
            created_at=row.get("created_at") or "",
            updated_at=row.get("updated_at") or "",
            error=row.get("error"),
            model_name=row.get("model_name"),
            store_only=True,
            total_input_tokens=row.get("total_input_tokens") or 0,
            total_output_tokens=row.get("total_output_tokens") or 0,
            total_tokens=row.get("total_tokens") or 0,
            llm_call_count=row.get("llm_call_count") or 0,
            lead_agent_tokens=row.get("lead_agent_tokens") or 0,
            subagent_tokens=row.get("subagent_tokens") or 0,
            middleware_tokens=row.get("middleware_tokens") or 0,
            message_count=row.get("message_count") or 0,
            last_ai_message=row.get("last_ai_message"),
            first_human_message=row.get("first_human_message"),
        )

    async def update_run_completion(self, run_id: str, **kwargs) -> None:
        """Persist token usage and completion data to the backing store.

        Non-None ``kwargs`` matching record attributes (except ``status``) are
        also mirrored onto the in-memory record. If the store reports that no
        row was updated (returns False), the row is recreated from the
        in-memory snapshot and the update is attempted once more. Failures
        are logged, not raised.
        """
        row_recovery_payload: dict[str, Any] | None = None
        async with self._lock:
            record = self._runs.get(run_id)
            if record is not None:
                for key, value in kwargs.items():
                    if key == "status":
                        continue
                    if hasattr(record, key) and value is not None:
                        setattr(record, key, value)
                record.updated_at = _now_iso()
                row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error"))
        if self._store is None:
            return
        try:
            updated = await self._call_store_with_retry(
                "update_run_completion",
                run_id,
                lambda: self._store.update_run_completion(run_id, **kwargs),
            )
            if updated is False:
                if row_recovery_payload is None:
                    logger.warning("Failed to recreate missing run %s for completion persistence", run_id)
                    return
                if not await self._persist_snapshot_to_store(run_id, row_recovery_payload):
                    return
                recovered = await self._call_store_with_retry(
                    "update_run_completion",
                    run_id,
                    lambda: self._store.update_run_completion(run_id, **kwargs),
                )
                if recovered is False:
                    logger.warning("Run completion update for %s affected no rows after row recreation", run_id)
        except Exception:
            logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)

    async def update_run_progress(self, run_id: str, **kwargs) -> None:
        """Persist a running token/message snapshot without changing status.

        When the run is known in memory, the snapshot is applied and persisted
        only while its status is ``running``. Unknown runs are still forwarded
        to the store.
        """
        should_persist = True
        async with self._lock:
            record = self._runs.get(run_id)
            if record is not None:
                should_persist = record.status == RunStatus.running
            if record is not None and should_persist:
                for key, value in kwargs.items():
                    if hasattr(record, key) and value is not None:
                        setattr(record, key, value)
                record.updated_at = _now_iso()
        if should_persist and self._store is not None:
            # Unlike the other store writes, this call is not wrapped in
            # _call_store_with_retry; a failure only logs a warning.
            try:
                await self._store.update_run_progress(run_id, **kwargs)
            except Exception:
                logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)

    async def _register_new_run_locked(self, record: RunRecord) -> None:
        """Insert *record* into ``_runs`` and persist it, rolling back on failure.

        Must be called with ``self._lock`` held. Because every reader takes
        the same lock, no reader can see the record before its store row
        exists. If persistence raises, or the awaiting task is cancelled,
        the in-memory insert is undone and the exception propagates.
        """
        run_id = record.run_id
        self._runs[run_id] = record
        persisted = False
        try:
            await self._persist_new_run_to_store(record)
            persisted = True
        except Exception:
            logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
            raise
        finally:
            # Also covers cancellation, which bypasses ``except Exception``.
            if not persisted:
                self._runs.pop(run_id, None)

    async def create(
        self,
        thread_id: str,
        assistant_id: str | None = None,
        *,
        on_disconnect: DisconnectMode = DisconnectMode.cancel,
        metadata: dict | None = None,
        kwargs: dict | None = None,
        multitask_strategy: str = "reject",
    ) -> RunRecord:
        """Create a new pending run and register it.

        Raises:
            Whatever the store raises when the initial row cannot be persisted;
            the run is then not registered in memory.
        """
        run_id = str(uuid.uuid4())
        now = _now_iso()
        record = RunRecord(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            status=RunStatus.pending,
            on_disconnect=on_disconnect,
            multitask_strategy=multitask_strategy,
            metadata=metadata or {},
            kwargs=kwargs or {},
            created_at=now,
            updated_at=now,
        )
        async with self._lock:
            await self._register_new_run_locked(record)
        logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
        return record

    async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None:
        """Return a run record by ID, or ``None``.

        Args:
            run_id: The run ID to look up.
            user_id: Optional user ID for permission filtering when hydrating from store.
        """
        async with self._lock:
            record = self._runs.get(run_id)
            if record is not None:
                return record
        if self._store is None:
            return None
        try:
            row = await self._store.get(run_id, user_id=user_id)
        except Exception:
            logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True)
            return None
        # Re-check after store await: a concurrent create() may have inserted the
        # in-memory record while the store call was in flight.
        async with self._lock:
            record = self._runs.get(run_id)
            if record is not None:
                return record
        if row is None:
            return None
        try:
            return self._record_from_store(row)
        except Exception:
            logger.warning("Failed to map store row for run %s", run_id, exc_info=True)
            return None

    async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None:
        """Return a run record by ID, checking the persistent store as fallback.

        Alias for :meth:`get` for backward compatibility.
        """
        return await self.get(run_id, user_id=user_id)

    async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]:
        """Return runs for a given thread, newest first, at most ``limit`` records.

        In-memory runs take precedence only when the same ``run_id`` exists in both
        memory and the backing store. The merged result is then sorted newest-first
        by ``created_at`` and trimmed to ``limit`` (default 100).

        Args:
            thread_id: The thread ID to filter by.
            user_id: Optional user ID for permission filtering when hydrating from store.
            limit: Maximum number of runs to return.
        """
        async with self._lock:
            # Dict insertion order gives deterministic results when timestamps tie.
            memory_records = [r for r in self._runs.values() if r.thread_id == thread_id]
        if self._store is None:
            return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
        records_by_id = {record.run_id: record for record in memory_records}
        # Request a full ``limit`` page from the store. In-memory runs normally
        # also have a store row (creation persists them), so shrinking the page
        # by ``len(memory_records)`` would spend slots on duplicates and could
        # drop runs that belong in the newest ``limit``.
        store_limit = max(0, limit)
        try:
            rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit)
        except Exception:
            logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True)
            return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit]
        for row in rows:
            run_id = row.get("run_id")
            if run_id and run_id not in records_by_id:
                try:
                    records_by_id[run_id] = self._record_from_store(row)
                except Exception:
                    logger.warning("Failed to map store row for run %s", run_id, exc_info=True)
        return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit]

    async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
        """Transition a run to a new status.

        Unknown runs are logged and ignored. Persistence is best-effort.
        """
        async with self._lock:
            record = self._runs.get(run_id)
            if record is None:
                logger.warning("set_status called for unknown run %s", run_id)
                return
            record.status = status
            record.updated_at = _now_iso()
            if error is not None:
                record.error = error
        await self._persist_status(record, status, error=error)
        logger.info("Run %s -> %s", run_id, status.value)

    async def _persist_model_name(self, run_id: str, model_name: str | None) -> None:
        """Best-effort persist model_name update to the backing store."""
        if self._store is None:
            return
        try:
            await self._call_store_with_retry(
                "update_model_name",
                run_id,
                lambda: self._store.update_model_name(run_id, model_name),
            )
        except Exception:
            logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True)

    async def update_model_name(self, run_id: str, model_name: str | None) -> None:
        """Update the model name for a run (unknown runs are logged and ignored)."""
        async with self._lock:
            record = self._runs.get(run_id)
            if record is None:
                logger.warning("update_model_name called for unknown run %s", run_id)
                return
            record.model_name = model_name
            record.updated_at = _now_iso()
        await self._persist_model_name(run_id, model_name)
        logger.info("Run %s model_name=%s", run_id, model_name)

    async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
        """Request cancellation of a run.

        Args:
            run_id: The run ID to cancel.
            action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state.

        Sets the abort event with the action reason and cancels the asyncio task.

        Returns ``True`` if cancellation was initiated **or** the run was already
        interrupted (idempotent — a second cancel is a no-op success).
        Returns ``False`` only when the run is unknown to this worker or has
        reached a terminal state other than interrupted (completed, failed, etc.).
        """
        async with self._lock:
            record = self._runs.get(run_id)
            if record is None:
                return False
            if record.status == RunStatus.interrupted:
                return True  # idempotent — already cancelled on this worker
            if record.status not in (RunStatus.pending, RunStatus.running):
                return False
            record.abort_action = action
            record.abort_event.set()
            if record.task is not None and not record.task.done():
                record.task.cancel()
            record.status = RunStatus.interrupted
            record.updated_at = _now_iso()
        await self._persist_status(record, RunStatus.interrupted)
        logger.info("Run %s cancelled (action=%s)", run_id, action)
        return True

    async def create_or_reject(
        self,
        thread_id: str,
        assistant_id: str | None = None,
        *,
        on_disconnect: DisconnectMode = DisconnectMode.cancel,
        metadata: dict | None = None,
        kwargs: dict | None = None,
        multitask_strategy: str = "reject",
        model_name: str | None = None,
    ) -> RunRecord:
        """Atomically check for inflight runs and create a new one.

        For ``reject`` strategy, raises ``ConflictError`` if thread
        already has a pending/running run. For ``interrupt``/``rollback``,
        cancels inflight runs before creating.

        This method holds the lock across both the check and the insert,
        eliminating the TOCTOU race in separate ``has_inflight`` + ``create``.

        Inflight runs are only interrupted after the new run's store row has
        been written; if that write fails, the exception propagates and the
        inflight runs are left untouched.

        Raises:
            UnsupportedStrategyError: for strategies other than reject/interrupt/rollback.
            ConflictError: for ``reject`` when the thread has an inflight run.
        """
        run_id = str(uuid.uuid4())
        now = _now_iso()
        _supported_strategies = ("reject", "interrupt", "rollback")
        interrupted_records: list[RunRecord] = []
        async with self._lock:
            if multitask_strategy not in _supported_strategies:
                raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
            inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
            if multitask_strategy == "reject" and inflight:
                raise ConflictError(f"Thread {thread_id} already has an active run")
            if multitask_strategy in ("interrupt", "rollback") and inflight:
                logger.info(
                    "Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)",
                    len(inflight),
                    thread_id,
                    multitask_strategy,
                )
            record = RunRecord(
                run_id=run_id,
                thread_id=thread_id,
                assistant_id=assistant_id,
                status=RunStatus.pending,
                on_disconnect=on_disconnect,
                multitask_strategy=multitask_strategy,
                metadata=metadata or {},
                kwargs=kwargs or {},
                created_at=now,
                updated_at=now,
                model_name=model_name,
            )
            await self._register_new_run_locked(record)
            if multitask_strategy in ("interrupt", "rollback") and inflight:
                for r in inflight:
                    r.abort_action = multitask_strategy
                    r.abort_event.set()
                    if r.task is not None and not r.task.done():
                        r.task.cancel()
                    r.status = RunStatus.interrupted
                    r.updated_at = now
                    interrupted_records.append(r)
        # Status writes for the interrupted runs happen outside the lock.
        for interrupted_record in interrupted_records:
            await self._persist_status(interrupted_record, RunStatus.interrupted)
        logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
        return record

    async def reconcile_orphaned_inflight_runs(
        self,
        *,
        error: str,
        before: str | None = None,
    ) -> list[RunRecord]:
        """Mark persisted active runs as failed when no local task owns them.

        Gateway runs are process-local: the asyncio task and abort event live in
        memory, while the run row is durable. After a SQLite-backed gateway
        restart, any persisted ``pending`` or ``running`` row created before
        startup cannot still have a local worker. This recovery step turns that
        ambiguous state into an explicit error instead of letting the UI show an
        indefinite active run.

        Returns:
            The records whose error status was successfully persisted.
        """
        if self._store is None:
            return []
        try:
            rows = await self._call_store_with_retry(
                "list_inflight",
                "*",
                lambda: self._store.list_inflight(before=before),
            )
        except Exception:
            logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True)
            return []
        recovered: list[RunRecord] = []
        now = _now_iso()
        for row in rows:
            try:
                record = self._record_from_store(row)
            except Exception:
                logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True)
                continue
            async with self._lock:
                live_record = self._runs.get(record.run_id)
                # A live pending/running record means this process still owns the run.
                if live_record is not None and live_record.status in (RunStatus.pending, RunStatus.running):
                    continue
            record.status = RunStatus.error
            record.error = error
            record.updated_at = now
            persisted = await self._persist_status(record, RunStatus.error, error=error)
            if not persisted:
                logger.warning("Skipped orphaned run %s recovery because error status was not persisted", record.run_id)
                continue
            recovered.append(record)
        if recovered:
            logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered))
        return recovered

    async def has_inflight(self, thread_id: str) -> bool:
        """Return ``True`` if *thread_id* has a pending or running run."""
        async with self._lock:
            return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())

    async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
        """Remove a run record after an optional delay."""
        if delay > 0:
            await asyncio.sleep(delay)
        async with self._lock:
            self._runs.pop(run_id, None)
        logger.debug("Run record %s cleaned up", run_id)
class ConflictError(Exception):
    """Signal that a thread already has a pending or running run.

    ``RunManager.create_or_reject`` raises this under the ``reject``
    multitask strategy instead of creating a second active run.
    """
class UnsupportedStrategyError(Exception):
    """Signal a ``multitask_strategy`` value that has no implementation yet.

    ``RunManager.create_or_reject`` accepts only ``reject``, ``interrupt``
    and ``rollback``; any other value triggers this error.
    """