"""In-memory run registry with optional persistent RunStore backing.""" from __future__ import annotations import asyncio import logging import sqlite3 import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from deerflow.utils.time import now_iso as _now_iso from .schemas import DisconnectMode, RunStatus if TYPE_CHECKING: from deerflow.runtime.runs.store.base import RunStore logger = logging.getLogger(__name__) _RETRYABLE_SQLITE_MESSAGES = ( "database is locked", "database table is locked", "database is busy", ) _RETRYABLE_SQLITE_ERROR_CODES = { sqlite3.SQLITE_BUSY, sqlite3.SQLITE_LOCKED, } def _is_retryable_persistence_error(exc: BaseException) -> bool: """Return True for transient SQLite persistence failures. SQLite lock contention normally surfaces through either sqlite3 exceptions or SQLAlchemy wrappers. The short bounded retry here protects run status finalization from transient writer pressure without hiding permanent failures forever. """ pending: list[BaseException] = [exc] seen: set[int] = set() while pending: current = pending.pop() if id(current) in seen: continue seen.add(id(current)) message = str(current).lower() if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES): return True if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)): error_code = getattr(current, "sqlite_errorcode", None) if error_code in _RETRYABLE_SQLITE_ERROR_CODES: return True for chained in (getattr(current, "orig", None), current.__cause__, current.__context__): if isinstance(chained, BaseException): pending.append(chained) return False @dataclass(frozen=True) class PersistenceRetryPolicy: """Bounded retry policy for short run-store writes.""" max_attempts: int = 5 initial_delay: float = 0.05 max_delay: float = 1.0 backoff_factor: float = 2.0 @dataclass class RunRecord: """Mutable record for a single run.""" run_id: str thread_id: str assistant_id: str | None status: RunStatus on_disconnect: DisconnectMode multitask_strategy: str = "reject" metadata: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict) created_at: str = "" updated_at: str = "" task: asyncio.Task | None = field(default=None, repr=False) abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False) abort_action: str = "interrupt" error: str | None = None model_name: str | None = None store_only: bool = False total_input_tokens: int = 0 total_output_tokens: int = 0 total_tokens: int = 0 llm_call_count: int = 0 lead_agent_tokens: int = 0 subagent_tokens: int = 0 middleware_tokens: int = 0 message_count: int = 0 last_ai_message: str | None = None first_human_message: str | None = None class RunManager: """In-memory run registry with optional persistent RunStore backing. All mutations are protected by an asyncio lock. When a ``store`` is provided, serializable metadata is also persisted to the store so that run history survives process restarts. """ def __init__( self, store: RunStore | None = None, *, persistence_retry_policy: PersistenceRetryPolicy | None = None, ) -> None: self._runs: dict[str, RunRecord] = {} self._lock = asyncio.Lock() self._store = store self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy() @staticmethod def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: return { "thread_id": record.thread_id, "assistant_id": record.assistant_id, "status": record.status.value, "multitask_strategy": record.multitask_strategy, "metadata": record.metadata or {}, "kwargs": record.kwargs or {}, "error": error if error is not None else record.error, "created_at": record.created_at, "model_name": record.model_name, } async def _call_store_with_retry( self, operation_name: str, run_id: str, operation: Callable[[], Awaitable[Any]], ) -> Any: """Run a short store operation with bounded retries for SQLite pressure.""" policy = self._persistence_retry_policy attempt = 1 delay = policy.initial_delay while True: try: return await operation() except Exception as exc: retryable = _is_retryable_persistence_error(exc) if attempt >= policy.max_attempts or not retryable: raise logger.warning( "Transient persistence failure during %s for run %s (attempt %d/%d); retrying", operation_name, run_id, attempt, policy.max_attempts, exc_info=True, ) if delay > 0: await asyncio.sleep(delay) delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay) attempt += 1 async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool: """Best-effort persist a previously captured run snapshot.""" if self._store is None: return True try: await self._call_store_with_retry( "put", run_id, lambda: self._store.put(run_id, **payload), ) return True except Exception: logger.warning("Failed to persist run %s to store", run_id, exc_info=True) return False async def _persist_new_run_to_store(self, record: RunRecord) -> None: """Persist a newly created run record to the backing store. Initial run creation is part of the run visibility boundary: callers should not observe a run in memory unless its backing store row exists. Unlike follow-up status/model updates, failures are propagated so the caller can treat creation as failed. Rollback is the caller's responsibility after inserting the record into ``_runs``. """ if self._store is None: return await self._call_store_with_retry( "put", record.run_id, lambda: self._store.put(record.run_id, **self._store_put_payload(record)), ) async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool: """Best-effort persist run record to backing store.""" return await self._persist_snapshot_to_store( record.run_id, self._store_put_payload(record, error=error), ) async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool: """Best-effort persist a status transition to the backing store.""" if self._store is None: return True row_recovery_payload = self._store_put_payload(record, error=error) try: updated = await self._call_store_with_retry( "update_status", record.run_id, lambda: self._store.update_status(record.run_id, status.value, error=error), ) if updated is False: return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload) return True except Exception: logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True) return False @staticmethod def _record_from_store(row: dict[str, Any]) -> RunRecord: """Build a read-only runtime record from a serialized store row. NULL status/on_disconnect columns (e.g. from rows written before those columns were added) default to ``pending`` and ``cancel`` respectively. """ return RunRecord( run_id=row["run_id"], thread_id=row["thread_id"], assistant_id=row.get("assistant_id"), status=RunStatus(row.get("status") or RunStatus.pending.value), on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value), multitask_strategy=row.get("multitask_strategy") or "reject", metadata=row.get("metadata") or {}, kwargs=row.get("kwargs") or {}, created_at=row.get("created_at") or "", updated_at=row.get("updated_at") or "", error=row.get("error"), model_name=row.get("model_name"), store_only=True, total_input_tokens=row.get("total_input_tokens") or 0, total_output_tokens=row.get("total_output_tokens") or 0, total_tokens=row.get("total_tokens") or 0, llm_call_count=row.get("llm_call_count") or 0, lead_agent_tokens=row.get("lead_agent_tokens") or 0, subagent_tokens=row.get("subagent_tokens") or 0, middleware_tokens=row.get("middleware_tokens") or 0, message_count=row.get("message_count") or 0, last_ai_message=row.get("last_ai_message"), first_human_message=row.get("first_human_message"), ) async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" row_recovery_payload: dict[str, Any] | None = None async with self._lock: record = self._runs.get(run_id) if record is not None: for key, value in kwargs.items(): if key == "status": continue if hasattr(record, key) and value is not None: setattr(record, key, value) record.updated_at = _now_iso() row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error")) if self._store is None: return try: updated = await self._call_store_with_retry( "update_run_completion", run_id, lambda: self._store.update_run_completion(run_id, **kwargs), ) if updated is False: if row_recovery_payload is None: logger.warning("Failed to recreate missing run %s for completion persistence", run_id) return if not await self._persist_snapshot_to_store(run_id, row_recovery_payload): return recovered = await self._call_store_with_retry( "update_run_completion", run_id, lambda: self._store.update_run_completion(run_id, **kwargs), ) if recovered is False: logger.warning("Run completion update for %s affected no rows after row recreation", run_id) except Exception: logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) async def update_run_progress(self, run_id: str, **kwargs) -> None: """Persist a running token/message snapshot without changing status.""" should_persist = True async with self._lock: record = self._runs.get(run_id) if record is not None: should_persist = record.status == RunStatus.running if record is not None and should_persist: for key, value in kwargs.items(): if hasattr(record, key) and value is not None: setattr(record, key, value) record.updated_at = _now_iso() if should_persist and self._store is not None: try: await self._store.update_run_progress(run_id, **kwargs) except Exception: logger.warning("Failed to persist run progress for %s", run_id, exc_info=True) async def create( self, thread_id: str, assistant_id: str | None = None, *, on_disconnect: DisconnectMode = DisconnectMode.cancel, metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", ) -> RunRecord: """Create a new pending run and register it.""" run_id = str(uuid.uuid4()) now = _now_iso() record = RunRecord( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, status=RunStatus.pending, on_disconnect=on_disconnect, multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, created_at=now, updated_at=now, ) async with self._lock: self._runs[run_id] = record persisted = False try: await self._persist_new_run_to_store(record) persisted = True except Exception: logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True) raise finally: # Also covers cancellation, which bypasses ``except Exception``. if not persisted: self._runs.pop(run_id, None) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: """Return a run record by ID, or ``None``. Args: run_id: The run ID to look up. user_id: Optional user ID for permission filtering when hydrating from store. """ async with self._lock: record = self._runs.get(run_id) if record is not None: return record if self._store is None: return None try: row = await self._store.get(run_id, user_id=user_id) except Exception: logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True) return None # Re-check after store await: a concurrent create() may have inserted the # in-memory record while the store call was in flight. async with self._lock: record = self._runs.get(run_id) if record is not None: return record if row is None: return None try: return self._record_from_store(row) except Exception: logger.warning("Failed to map store row for run %s", run_id, exc_info=True) return None async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: """Return a run record by ID, checking the persistent store as fallback. Alias for :meth:`get` for backward compatibility. """ return await self.get(run_id, user_id=user_id) async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]: """Return runs for a given thread, newest first, at most ``limit`` records. In-memory runs take precedence only when the same ``run_id`` exists in both memory and the backing store. The merged result is then sorted newest-first by ``created_at`` and trimmed to ``limit`` (default 100). Args: thread_id: The thread ID to filter by. user_id: Optional user ID for permission filtering when hydrating from store. limit: Maximum number of runs to return. """ async with self._lock: # Dict insertion order gives deterministic results when timestamps tie. memory_records = [r for r in self._runs.values() if r.thread_id == thread_id] if self._store is None: return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] records_by_id = {record.run_id: record for record in memory_records} store_limit = max(0, limit - len(memory_records)) try: rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit) except Exception: logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True) return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] for row in rows: run_id = row.get("run_id") if run_id and run_id not in records_by_id: try: records_by_id[run_id] = self._record_from_store(row) except Exception: logger.warning("Failed to map store row for run %s", run_id, exc_info=True) return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit] async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Transition a run to a new status.""" async with self._lock: record = self._runs.get(run_id) if record is None: logger.warning("set_status called for unknown run %s", run_id) return record.status = status record.updated_at = _now_iso() if error is not None: record.error = error await self._persist_status(record, status, error=error) logger.info("Run %s -> %s", run_id, status.value) async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: """Best-effort persist model_name update to the backing store.""" if self._store is None: return try: await self._call_store_with_retry( "update_model_name", run_id, lambda: self._store.update_model_name(run_id, model_name), ) except Exception: logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) async def update_model_name(self, run_id: str, model_name: str | None) -> None: """Update the model name for a run.""" async with self._lock: record = self._runs.get(run_id) if record is None: logger.warning("update_model_name called for unknown run %s", run_id) return record.model_name = model_name record.updated_at = _now_iso() await self._persist_model_name(run_id, model_name) logger.info("Run %s model_name=%s", run_id, model_name) async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: """Request cancellation of a run. Args: run_id: The run ID to cancel. action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state. Sets the abort event with the action reason and cancels the asyncio task. Returns ``True`` if cancellation was initiated **or** the run was already interrupted (idempotent — a second cancel is a no-op success). Returns ``False`` only when the run is unknown to this worker or has reached a terminal state other than interrupted (completed, failed, etc.). """ async with self._lock: record = self._runs.get(run_id) if record is None: return False if record.status == RunStatus.interrupted: return True # idempotent — already cancelled on this worker if record.status not in (RunStatus.pending, RunStatus.running): return False record.abort_action = action record.abort_event.set() if record.task is not None and not record.task.done(): record.task.cancel() record.status = RunStatus.interrupted record.updated_at = _now_iso() await self._persist_status(record, RunStatus.interrupted) logger.info("Run %s cancelled (action=%s)", run_id, action) return True async def create_or_reject( self, thread_id: str, assistant_id: str | None = None, *, on_disconnect: DisconnectMode = DisconnectMode.cancel, metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", model_name: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. For ``reject`` strategy, raises ``ConflictError`` if thread already has a pending/running run. For ``interrupt``/``rollback``, cancels inflight runs before creating. This method holds the lock across both the check and the insert, eliminating the TOCTOU race in separate ``has_inflight`` + ``create``. """ run_id = str(uuid.uuid4()) now = _now_iso() _supported_strategies = ("reject", "interrupt", "rollback") interrupted_records: list[RunRecord] = [] async with self._lock: if multitask_strategy not in _supported_strategies: raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}") inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)] if multitask_strategy == "reject" and inflight: raise ConflictError(f"Thread {thread_id} already has an active run") if multitask_strategy in ("interrupt", "rollback") and inflight: logger.info( "Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)", len(inflight), thread_id, multitask_strategy, ) record = RunRecord( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, status=RunStatus.pending, on_disconnect=on_disconnect, multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, created_at=now, updated_at=now, model_name=model_name, ) self._runs[run_id] = record persisted = False try: await self._persist_new_run_to_store(record) persisted = True except Exception: logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True) raise finally: # Also covers cancellation, which bypasses ``except Exception``. if not persisted: self._runs.pop(run_id, None) if multitask_strategy in ("interrupt", "rollback") and inflight: for r in inflight: r.abort_action = multitask_strategy r.abort_event.set() if r.task is not None and not r.task.done(): r.task.cancel() r.status = RunStatus.interrupted r.updated_at = now interrupted_records.append(r) for interrupted_record in interrupted_records: await self._persist_status(interrupted_record, RunStatus.interrupted) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record async def reconcile_orphaned_inflight_runs( self, *, error: str, before: str | None = None, ) -> list[RunRecord]: """Mark persisted active runs as failed when no local task owns them. Gateway runs are process-local: the asyncio task and abort event live in memory, while the run row is durable. After a SQLite-backed gateway restart, any persisted ``pending`` or ``running`` row created before startup cannot still have a local worker. This recovery step turns that ambiguous state into an explicit error instead of letting the UI show an indefinite active run. """ if self._store is None: return [] try: rows = await self._call_store_with_retry( "list_inflight", "*", lambda: self._store.list_inflight(before=before), ) except Exception: logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True) return [] recovered: list[RunRecord] = [] now = _now_iso() for row in rows: try: record = self._record_from_store(row) except Exception: logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True) continue async with self._lock: live_record = self._runs.get(record.run_id) if live_record is not None and live_record.status in (RunStatus.pending, RunStatus.running): continue record.status = RunStatus.error record.error = error record.updated_at = now persisted = await self._persist_status(record, RunStatus.error, error=error) if not persisted: logger.warning("Skipped orphaned run %s recovery because error status was not persisted", record.run_id) continue recovered.append(record) if recovered: logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered)) return recovered async def has_inflight(self, thread_id: str) -> bool: """Return ``True`` if *thread_id* has a pending or running run.""" async with self._lock: return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values()) async def cleanup(self, run_id: str, *, delay: float = 300) -> None: """Remove a run record after an optional delay.""" if delay > 0: await asyncio.sleep(delay) async with self._lock: self._runs.pop(run_id, None) logger.debug("Run record %s cleaned up", run_id) async def shutdown(self, *, timeout: float = 5.0) -> None: """Cancel and bounded-await all in-flight runs on process shutdown. Chat runs execute in fire-and-forget background ``asyncio`` tasks that write checkpoints through a shared checkpointer. On shutdown the checkpointer's resources (e.g. the postgres connection pool owned by the gateway's ``AsyncExitStack``) are torn down; if a run task is still mid-graph at that point, langgraph's ``AsyncPregelLoop._checkpointer_put_after_previous`` runs its ``finally: await checkpointer.aput(...)`` against the closed pool. Because that put runs in a langgraph-internal task (not on ``run_agent``'s call stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the worker and surfaces as an unhandled exception during ``asyncio.run()`` shutdown (bytedance/deer-flow issue #3373). Draining in-flight runs *before* the checkpointer is closed lets each run that settles within ``timeout`` flush its final checkpoint while resources are still open. Only runs that do **not** settle on their own are marked ``interrupted`` — a run that completes (e.g. ``success``) during the drain keeps its real terminal status instead of being blanket-overwritten. The whole drain, including the trailing status persistence, is bounded by ``timeout`` so a run stuck in cleanup (or a slow store under DB pressure) cannot hang worker shutdown — the precondition for the signal-reentrancy deadlock guarded by ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active after ``timeout`` are logged and may still race teardown. """ loop = asyncio.get_running_loop() deadline = loop.time() + timeout async with self._lock: inflight = [record for record in self._runs.values() if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()] for record in inflight: record.abort_action = "interrupt" record.abort_event.set() record.task.cancel() # type: ignore[union-attr] # filtered above # Status is decided AFTER the drain (below), not here: a run that # completes on its own during the drain must keep its real status. if not inflight: return tasks = [record.task for record in inflight] _, pending = await asyncio.wait(tasks, timeout=timeout) # Only mark/persist ``interrupted`` for runs that did not settle on their # own (still pending after the timeout, or ended cancelled). A run that # finished normally during the drain keeps the status it set for itself. to_persist: list[RunRecord] = [] async with self._lock: for record in inflight: task = record.task if task not in pending and not task.cancelled(): # Completed on its own — retrieve any surfaced exception so it # is not reported as "never retrieved", and keep its status. task.exception() # type: ignore[union-attr] # done & not cancelled continue if record.status in (RunStatus.pending, RunStatus.running): record.status = RunStatus.interrupted record.updated_at = _now_iso() to_persist.append(record) # Bound the trailing status persistence within the remaining budget so a # slow store (``_call_store_with_retry`` can back off under DB pressure) # cannot push shutdown past ``timeout``. if to_persist: remaining = deadline - loop.time() if remaining <= 0: logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist)) else: try: results = await asyncio.wait_for( asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True), timeout=remaining, ) except TimeoutError: logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist)) else: # ``_persist_status`` is best-effort: it catches and logs its # own failures, returning ``False``. Inspect the aggregate so a # partial failure is surfaced at shutdown level (with the # run_id) instead of being silently swallowed by the gather. for record, result in zip(to_persist, results): if isinstance(result, Exception): logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result) elif result is False: logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id) if pending: logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending)) logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout) class ConflictError(Exception): """Raised when multitask_strategy=reject and thread has inflight runs.""" class UnsupportedStrategyError(Exception): """Raised when a multitask_strategy value is not yet implemented."""