"""In-memory run registry with optional persistent RunStore backing.""" from __future__ import annotations import asyncio import logging import uuid from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from deerflow.utils.time import now_iso as _now_iso from .schemas import DisconnectMode, RunStatus if TYPE_CHECKING: from deerflow.runtime.runs.store.base import RunStore logger = logging.getLogger(__name__) @dataclass class RunRecord: """Mutable record for a single run.""" run_id: str thread_id: str assistant_id: str | None status: RunStatus on_disconnect: DisconnectMode multitask_strategy: str = "reject" metadata: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict) created_at: str = "" updated_at: str = "" task: asyncio.Task | None = field(default=None, repr=False) abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False) abort_action: str = "interrupt" error: str | None = None model_name: str | None = None store_only: bool = False class RunManager: """In-memory run registry with optional persistent RunStore backing. All mutations are protected by an asyncio lock. When a ``store`` is provided, serializable metadata is also persisted to the store so that run history survives process restarts. """ def __init__(self, store: RunStore | None = None) -> None: self._runs: dict[str, RunRecord] = {} self._lock = asyncio.Lock() self._store = store async def _persist_to_store(self, record: RunRecord) -> None: """Best-effort persist run record to backing store.""" if self._store is None: return try: await self._store.put( record.run_id, thread_id=record.thread_id, assistant_id=record.assistant_id, status=record.status.value, multitask_strategy=record.multitask_strategy, metadata=record.metadata or {}, kwargs=record.kwargs or {}, created_at=record.created_at, model_name=record.model_name, ) except Exception: logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Best-effort persist a status transition to the backing store.""" if self._store is None: return try: await self._store.update_status(run_id, status.value, error=error) except Exception: logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) @staticmethod def _record_from_store(row: dict[str, Any]) -> RunRecord: """Build a read-only runtime record from a serialized store row. NULL status/on_disconnect columns (e.g. from rows written before those columns were added) default to ``pending`` and ``cancel`` respectively. """ return RunRecord( run_id=row["run_id"], thread_id=row["thread_id"], assistant_id=row.get("assistant_id"), status=RunStatus(row.get("status") or RunStatus.pending.value), on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value), multitask_strategy=row.get("multitask_strategy") or "reject", metadata=row.get("metadata") or {}, kwargs=row.get("kwargs") or {}, created_at=row.get("created_at") or "", updated_at=row.get("updated_at") or "", error=row.get("error"), model_name=row.get("model_name"), store_only=True, ) async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" if self._store is not None: try: await self._store.update_run_completion(run_id, **kwargs) except Exception: logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) async def create( self, thread_id: str, assistant_id: str | None = None, *, on_disconnect: DisconnectMode = DisconnectMode.cancel, metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", ) -> RunRecord: """Create a new pending run and register it.""" run_id = str(uuid.uuid4()) now = _now_iso() record = RunRecord( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, status=RunStatus.pending, on_disconnect=on_disconnect, multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, created_at=now, updated_at=now, ) async with self._lock: self._runs[run_id] = record await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: """Return a run record by ID, or ``None``. Args: run_id: The run ID to look up. user_id: Optional user ID for permission filtering when hydrating from store. """ async with self._lock: record = self._runs.get(run_id) if record is not None: return record if self._store is None: return None try: row = await self._store.get(run_id, user_id=user_id) except Exception: logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True) return None # Re-check after store await: a concurrent create() may have inserted the # in-memory record while the store call was in flight. async with self._lock: record = self._runs.get(run_id) if record is not None: return record if row is None: return None try: return self._record_from_store(row) except Exception: logger.warning("Failed to map store row for run %s", run_id, exc_info=True) return None async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: """Return a run record by ID, checking the persistent store as fallback. Alias for :meth:`get` for backward compatibility. """ return await self.get(run_id, user_id=user_id) async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]: """Return runs for a given thread, newest first, at most ``limit`` records. In-memory runs take precedence only when the same ``run_id`` exists in both memory and the backing store. The merged result is then sorted newest-first by ``created_at`` and trimmed to ``limit`` (default 100). Args: thread_id: The thread ID to filter by. user_id: Optional user ID for permission filtering when hydrating from store. limit: Maximum number of runs to return. """ async with self._lock: # Dict insertion order gives deterministic results when timestamps tie. memory_records = [r for r in self._runs.values() if r.thread_id == thread_id] if self._store is None: return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] records_by_id = {record.run_id: record for record in memory_records} store_limit = max(0, limit - len(memory_records)) try: rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit) except Exception: logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True) return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] for row in rows: run_id = row.get("run_id") if run_id and run_id not in records_by_id: try: records_by_id[run_id] = self._record_from_store(row) except Exception: logger.warning("Failed to map store row for run %s", run_id, exc_info=True) return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit] async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Transition a run to a new status.""" async with self._lock: record = self._runs.get(run_id) if record is None: logger.warning("set_status called for unknown run %s", run_id) return record.status = status record.updated_at = _now_iso() if error is not None: record.error = error await self._persist_status(run_id, status, error=error) logger.info("Run %s -> %s", run_id, status.value) async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: """Best-effort persist model_name update to the backing store.""" if self._store is None: return try: await self._store.update_model_name(run_id, model_name) except Exception: logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) async def update_model_name(self, run_id: str, model_name: str | None) -> None: """Update the model name for a run.""" async with self._lock: record = self._runs.get(run_id) if record is None: logger.warning("update_model_name called for unknown run %s", run_id) return record.model_name = model_name record.updated_at = _now_iso() await self._persist_model_name(run_id, model_name) logger.info("Run %s model_name=%s", run_id, model_name) async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: """Request cancellation of a run. Args: run_id: The run ID to cancel. action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state. Sets the abort event with the action reason and cancels the asyncio task. Returns ``True`` if the run was in-flight and cancellation was initiated. """ async with self._lock: record = self._runs.get(run_id) if record is None: return False if record.status not in (RunStatus.pending, RunStatus.running): return False record.abort_action = action record.abort_event.set() if record.task is not None and not record.task.done(): record.task.cancel() record.status = RunStatus.interrupted record.updated_at = _now_iso() await self._persist_status(run_id, RunStatus.interrupted) logger.info("Run %s cancelled (action=%s)", run_id, action) return True async def create_or_reject( self, thread_id: str, assistant_id: str | None = None, *, on_disconnect: DisconnectMode = DisconnectMode.cancel, metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", model_name: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. For ``reject`` strategy, raises ``ConflictError`` if thread already has a pending/running run. For ``interrupt``/``rollback``, cancels inflight runs before creating. This method holds the lock across both the check and the insert, eliminating the TOCTOU race in separate ``has_inflight`` + ``create``. """ run_id = str(uuid.uuid4()) now = _now_iso() _supported_strategies = ("reject", "interrupt", "rollback") interrupted_run_ids: list[str] = [] async with self._lock: if multitask_strategy not in _supported_strategies: raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}") inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)] if multitask_strategy == "reject" and inflight: raise ConflictError(f"Thread {thread_id} already has an active run") if multitask_strategy in ("interrupt", "rollback") and inflight: for r in inflight: r.abort_action = multitask_strategy r.abort_event.set() if r.task is not None and not r.task.done(): r.task.cancel() r.status = RunStatus.interrupted r.updated_at = now interrupted_run_ids.append(r.run_id) logger.info( "Cancelled %d inflight run(s) on thread %s (strategy=%s)", len(inflight), thread_id, multitask_strategy, ) record = RunRecord( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, status=RunStatus.pending, on_disconnect=on_disconnect, multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, created_at=now, updated_at=now, model_name=model_name, ) self._runs[run_id] = record for interrupted_run_id in interrupted_run_ids: await self._persist_status(interrupted_run_id, RunStatus.interrupted) await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record async def has_inflight(self, thread_id: str) -> bool: """Return ``True`` if *thread_id* has a pending or running run.""" async with self._lock: return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values()) async def cleanup(self, run_id: str, *, delay: float = 300) -> None: """Remove a run record after an optional delay.""" if delay > 0: await asyncio.sleep(delay) async with self._lock: self._runs.pop(run_id, None) logger.debug("Run record %s cleaned up", run_id) class ConflictError(Exception): """Raised when multitask_strategy=reject and thread has inflight runs.""" class UnsupportedStrategyError(Exception): """Raised when a multitask_strategy value is not yet implemented."""