"""In-memory run registry with optional persistent RunStore backing.""" from __future__ import annotations import asyncio import logging import sqlite3 import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from deerflow.utils.time import now_iso as _now_iso from .schemas import DisconnectMode, RunStatus if TYPE_CHECKING: from deerflow.runtime.runs.store.base import RunStore logger = logging.getLogger(__name__) _RETRYABLE_SQLITE_MESSAGES = ( "database is locked", "database table is locked", "database is busy", ) _RETRYABLE_SQLITE_ERROR_CODES = { sqlite3.SQLITE_BUSY, sqlite3.SQLITE_LOCKED, } def _is_retryable_persistence_error(exc: BaseException) -> bool: """Return True for transient SQLite persistence failures. SQLite lock contention normally surfaces through either sqlite3 exceptions or SQLAlchemy wrappers. The short bounded retry here protects run status finalization from transient writer pressure without hiding permanent failures forever. """ pending: list[BaseException] = [exc] seen: set[int] = set() while pending: current = pending.pop() if id(current) in seen: continue seen.add(id(current)) message = str(current).lower() if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES): return True if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)): error_code = getattr(current, "sqlite_errorcode", None) if error_code in _RETRYABLE_SQLITE_ERROR_CODES: return True for chained in (getattr(current, "orig", None), current.__cause__, current.__context__): if isinstance(chained, BaseException): pending.append(chained) return False @dataclass(frozen=True) class PersistenceRetryPolicy: """Bounded retry policy for short run-store writes.""" max_attempts: int = 5 initial_delay: float = 0.05 max_delay: float = 1.0 backoff_factor: float = 2.0 @dataclass class RunRecord: """Mutable record for a single run.""" run_id: str thread_id: str assistant_id: str | None status: RunStatus on_disconnect: DisconnectMode multitask_strategy: str = "reject" metadata: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict) created_at: str = "" updated_at: str = "" task: asyncio.Task | None = field(default=None, repr=False) abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False) abort_action: str = "interrupt" error: str | None = None model_name: str | None = None store_only: bool = False total_input_tokens: int = 0 total_output_tokens: int = 0 total_tokens: int = 0 llm_call_count: int = 0 lead_agent_tokens: int = 0 subagent_tokens: int = 0 middleware_tokens: int = 0 message_count: int = 0 last_ai_message: str | None = None first_human_message: str | None = None class RunManager: """In-memory run registry with optional persistent RunStore backing. All mutations are protected by an asyncio lock. When a ``store`` is provided, serializable metadata is also persisted to the store so that run history survives process restarts. """ def __init__( self, store: RunStore | None = None, *, persistence_retry_policy: PersistenceRetryPolicy | None = None, ) -> None: self._runs: dict[str, RunRecord] = {} self._lock = asyncio.Lock() self._store = store self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy() @staticmethod def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: return { "thread_id": record.thread_id, "assistant_id": record.assistant_id, "status": record.status.value, "multitask_strategy": record.multitask_strategy, "metadata": record.metadata or {}, "kwargs": record.kwargs or {}, "error": error if error is not None else record.error, "created_at": record.created_at, "model_name": record.model_name, } async def _call_store_with_retry( self, operation_name: str, run_id: str, operation: Callable[[], Awaitable[Any]], ) -> Any: """Run a short store operation with bounded retries for SQLite pressure.""" policy = self._persistence_retry_policy attempt = 1 delay = policy.initial_delay while True: try: return await operation() except Exception as exc: retryable = _is_retryable_persistence_error(exc) if attempt >= policy.max_attempts or not retryable: raise logger.warning( "Transient persistence failure during %s for run %s (attempt %d/%d); retrying", operation_name, run_id, attempt, policy.max_attempts, exc_info=True, ) if delay > 0: await asyncio.sleep(delay) delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay) attempt += 1 async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool: """Best-effort persist a previously captured run snapshot.""" if self._store is None: return True try: await self._call_store_with_retry( "put", run_id, lambda: self._store.put(run_id, **payload), ) return True except Exception: logger.warning("Failed to persist run %s to store", run_id, exc_info=True) return False async def _persist_new_run_to_store(self, record: RunRecord) -> None: """Persist a newly created run record to the backing store. Initial run creation is part of the run visibility boundary: callers should not observe a run in memory unless its backing store row exists. Unlike follow-up status/model updates, failures are propagated so the caller can treat creation as failed. Rollback is the caller's responsibility after inserting the record into ``_runs``. """ if self._store is None: return await self._call_store_with_retry( "put", record.run_id, lambda: self._store.put(record.run_id, **self._store_put_payload(record)), ) async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool: """Best-effort persist run record to backing store.""" return await self._persist_snapshot_to_store( record.run_id, self._store_put_payload(record, error=error), ) async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool: """Best-effort persist a status transition to the backing store.""" if self._store is None: return True row_recovery_payload = self._store_put_payload(record, error=error) try: updated = await self._call_store_with_retry( "update_status", record.run_id, lambda: self._store.update_status(record.run_id, status.value, error=error), ) if updated is False: return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload) return True except Exception: logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True) return False @staticmethod def _record_from_store(row: dict[str, Any]) -> RunRecord: """Build a read-only runtime record from a serialized store row. NULL status/on_disconnect columns (e.g. from rows written before those columns were added) default to ``pending`` and ``cancel`` respectively. """ return RunRecord( run_id=row["run_id"], thread_id=row["thread_id"], assistant_id=row.get("assistant_id"), status=RunStatus(row.get("status") or RunStatus.pending.value), on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value), multitask_strategy=row.get("multitask_strategy") or "reject", metadata=row.get("metadata") or {}, kwargs=row.get("kwargs") or {}, created_at=row.get("created_at") or "", updated_at=row.get("updated_at") or "", error=row.get("error"), model_name=row.get("model_name"), store_only=True, total_input_tokens=row.get("total_input_tokens") or 0, total_output_tokens=row.get("total_output_tokens") or 0, total_tokens=row.get("total_tokens") or 0, llm_call_count=row.get("llm_call_count") or 0, lead_agent_tokens=row.get("lead_agent_tokens") or 0, subagent_tokens=row.get("subagent_tokens") or 0, middleware_tokens=row.get("middleware_tokens") or 0, message_count=row.get("message_count") or 0, last_ai_message=row.get("last_ai_message"), first_human_message=row.get("first_human_message"), ) async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" row_recovery_payload: dict[str, Any] | None = None async with self._lock: record = self._runs.get(run_id) if record is not None: for key, value in kwargs.items(): if key == "status": continue if hasattr(record, key) and value is not None: setattr(record, key, value) record.updated_at = _now_iso() row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error")) if self._store is None: return try: updated = await self._call_store_with_retry( "update_run_completion", run_id, lambda: self._store.update_run_completion(run_id, **kwargs), ) if updated is False: if row_recovery_payload is None: logger.warning("Failed to recreate missing run %s for completion persistence", run_id) return if not await self._persist_snapshot_to_store(run_id, row_recovery_payload): return recovered = await self._call_store_with_retry( "update_run_completion", run_id, lambda: self._store.update_run_completion(run_id, **kwargs), ) if recovered is False: logger.warning("Run completion update for %s affected no rows after row recreation", run_id) except Exception: logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) async def update_run_progress(self, run_id: str, **kwargs) -> None: """Persist a running token/message snapshot without changing status.""" should_persist = True async with self._lock: record = self._runs.get(run_id) if record is not None: should_persist = record.status == RunStatus.running if record is not None and should_persist: for key, value in kwargs.items(): if hasattr(record, key) and value is not None: setattr(record, key, value) record.updated_at = _now_iso() if should_persist and self._store is not None: try: await self._store.update_run_progress(run_id, **kwargs) except Exception: logger.warning("Failed to persist run progress for %s", run_id, exc_info=True) async def create( self, thread_id: str, assistant_id: str | None = None, *, on_disconnect: DisconnectMode = DisconnectMode.cancel, metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", ) -> RunRecord: """Create a new pending run and register it.""" run_id = str(uuid.uuid4()) now = _now_iso() record = RunRecord( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, status=RunStatus.pending, on_disconnect=on_disconnect, multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, created_at=now, updated_at=now, ) async with self._lock: self._runs[run_id] = record persisted = False try: await self._persist_new_run_to_store(record) persisted = True except Exception: logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True) raise finally: # Also covers cancellation, which bypasses ``except Exception``. if not persisted: self._runs.pop(run_id, None) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: """Return a run record by ID, or ``None``. Args: run_id: The run ID to look up. user_id: Optional user ID for permission filtering when hydrating from store. """ async with self._lock: record = self._runs.get(run_id) if record is not None: return record if self._store is None: return None try: row = await self._store.get(run_id, user_id=user_id) except Exception: logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True) return None # Re-check after store await: a concurrent create() may have inserted the # in-memory record while the store call was in flight. async with self._lock: record = self._runs.get(run_id) if record is not None: return record if row is None: return None try: return self._record_from_store(row) except Exception: logger.warning("Failed to map store row for run %s", run_id, exc_info=True) return None async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: """Return a run record by ID, checking the persistent store as fallback. Alias for :meth:`get` for backward compatibility. """ return await self.get(run_id, user_id=user_id) async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]: """Return runs for a given thread, newest first, at most ``limit`` records. In-memory runs take precedence only when the same ``run_id`` exists in both memory and the backing store. The merged result is then sorted newest-first by ``created_at`` and trimmed to ``limit`` (default 100). Args: thread_id: The thread ID to filter by. user_id: Optional user ID for permission filtering when hydrating from store. limit: Maximum number of runs to return. """ async with self._lock: # Dict insertion order gives deterministic results when timestamps tie. memory_records = [r for r in self._runs.values() if r.thread_id == thread_id] if self._store is None: return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] records_by_id = {record.run_id: record for record in memory_records} store_limit = max(0, limit - len(memory_records)) try: rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit) except Exception: logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True) return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] for row in rows: run_id = row.get("run_id") if run_id and run_id not in records_by_id: try: records_by_id[run_id] = self._record_from_store(row) except Exception: logger.warning("Failed to map store row for run %s", run_id, exc_info=True) return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit] async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Transition a run to a new status.""" async with self._lock: record = self._runs.get(run_id) if record is None: logger.warning("set_status called for unknown run %s", run_id) return record.status = status record.updated_at = _now_iso() if error is not None: record.error = error await self._persist_status(record, status, error=error) logger.info("Run %s -> %s", run_id, status.value) async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: """Best-effort persist model_name update to the backing store.""" if self._store is None: return try: await self._call_store_with_retry( "update_model_name", run_id, lambda: self._store.update_model_name(run_id, model_name), ) except Exception: logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) async def update_model_name(self, run_id: str, model_name: str | None) -> None: """Update the model name for a run.""" async with self._lock: record = self._runs.get(run_id) if record is None: logger.warning("update_model_name called for unknown run %s", run_id) return record.model_name = model_name record.updated_at = _now_iso() await self._persist_model_name(run_id, model_name) logger.info("Run %s model_name=%s", run_id, model_name) async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: """Request cancellation of a run. Args: run_id: The run ID to cancel. action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state. Sets the abort event with the action reason and cancels the asyncio task. Returns ``True`` if cancellation was initiated **or** the run was already interrupted (idempotent — a second cancel is a no-op success). Returns ``False`` only when the run is unknown to this worker or has reached a terminal state other than interrupted (completed, failed, etc.). """ async with self._lock: record = self._runs.get(run_id) if record is None: return False if record.status == RunStatus.interrupted: return True # idempotent — already cancelled on this worker if record.status not in (RunStatus.pending, RunStatus.running): return False record.abort_action = action record.abort_event.set() if record.task is not None and not record.task.done(): record.task.cancel() record.status = RunStatus.interrupted record.updated_at = _now_iso() await self._persist_status(record, RunStatus.interrupted) logger.info("Run %s cancelled (action=%s)", run_id, action) return True async def create_or_reject( self, thread_id: str, assistant_id: str | None = None, *, on_disconnect: DisconnectMode = DisconnectMode.cancel, metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", model_name: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. For ``reject`` strategy, raises ``ConflictError`` if thread already has a pending/running run. For ``interrupt``/``rollback``, cancels inflight runs before creating. This method holds the lock across both the check and the insert, eliminating the TOCTOU race in separate ``has_inflight`` + ``create``. """ run_id = str(uuid.uuid4()) now = _now_iso() _supported_strategies = ("reject", "interrupt", "rollback") interrupted_records: list[RunRecord] = [] async with self._lock: if multitask_strategy not in _supported_strategies: raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}") inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)] if multitask_strategy == "reject" and inflight: raise ConflictError(f"Thread {thread_id} already has an active run") if multitask_strategy in ("interrupt", "rollback") and inflight: logger.info( "Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)", len(inflight), thread_id, multitask_strategy, ) record = RunRecord( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, status=RunStatus.pending, on_disconnect=on_disconnect, multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, created_at=now, updated_at=now, model_name=model_name, ) self._runs[run_id] = record persisted = False try: await self._persist_new_run_to_store(record) persisted = True except Exception: logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True) raise finally: # Also covers cancellation, which bypasses ``except Exception``. if not persisted: self._runs.pop(run_id, None) if multitask_strategy in ("interrupt", "rollback") and inflight: for r in inflight: r.abort_action = multitask_strategy r.abort_event.set() if r.task is not None and not r.task.done(): r.task.cancel() r.status = RunStatus.interrupted r.updated_at = now interrupted_records.append(r) for interrupted_record in interrupted_records: await self._persist_status(interrupted_record, RunStatus.interrupted) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record async def reconcile_orphaned_inflight_runs( self, *, error: str, before: str | None = None, ) -> list[RunRecord]: """Mark persisted active runs as failed when no local task owns them. Gateway runs are process-local: the asyncio task and abort event live in memory, while the run row is durable. After a SQLite-backed gateway restart, any persisted ``pending`` or ``running`` row created before startup cannot still have a local worker. This recovery step turns that ambiguous state into an explicit error instead of letting the UI show an indefinite active run. """ if self._store is None: return [] try: rows = await self._call_store_with_retry( "list_inflight", "*", lambda: self._store.list_inflight(before=before), ) except Exception: logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True) return [] recovered: list[RunRecord] = [] now = _now_iso() for row in rows: try: record = self._record_from_store(row) except Exception: logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True) continue async with self._lock: live_record = self._runs.get(record.run_id) if live_record is not None and live_record.status in (RunStatus.pending, RunStatus.running): continue record.status = RunStatus.error record.error = error record.updated_at = now persisted = await self._persist_status(record, RunStatus.error, error=error) if not persisted: logger.warning("Skipped orphaned run %s recovery because error status was not persisted", record.run_id) continue recovered.append(record) if recovered: logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered)) return recovered async def has_inflight(self, thread_id: str) -> bool: """Return ``True`` if *thread_id* has a pending or running run.""" async with self._lock: return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values()) async def cleanup(self, run_id: str, *, delay: float = 300) -> None: """Remove a run record after an optional delay.""" if delay > 0: await asyncio.sleep(delay) async with self._lock: self._runs.pop(run_id, None) logger.debug("Run record %s cleaned up", run_id) class ConflictError(Exception): """Raised when multitask_strategy=reject and thread has inflight runs.""" class UnsupportedStrategyError(Exception): """Raised when a multitask_strategy value is not yet implemented."""