fix: harden run finalization persistence (#3155)

* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
This commit is contained in:
AochenShen99
2026-05-23 00:09:06 +08:00
committed by GitHub
parent f0bae28636
commit 66d6a6a4e8
8 changed files with 755 additions and 56 deletions
@@ -4,7 +4,9 @@ from __future__ import annotations
import asyncio
import logging
import sqlite3
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
@@ -17,6 +19,57 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_RETRYABLE_SQLITE_MESSAGES = (
"database is locked",
"database table is locked",
"database is busy",
)
_RETRYABLE_SQLITE_ERROR_CODES = {
sqlite3.SQLITE_BUSY,
sqlite3.SQLITE_LOCKED,
}
def _is_retryable_persistence_error(exc: BaseException) -> bool:
"""Return True for transient SQLite persistence failures.
SQLite lock contention normally surfaces through either sqlite3 exceptions
or SQLAlchemy wrappers. The short bounded retry here protects run status
finalization from transient writer pressure without hiding permanent
failures forever.
"""
pending: list[BaseException] = [exc]
seen: set[int] = set()
while pending:
current = pending.pop()
if id(current) in seen:
continue
seen.add(id(current))
message = str(current).lower()
if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES):
return True
if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)):
error_code = getattr(current, "sqlite_errorcode", None)
if error_code in _RETRYABLE_SQLITE_ERROR_CODES:
return True
for chained in (getattr(current, "orig", None), current.__cause__, current.__context__):
if isinstance(chained, BaseException):
pending.append(chained)
return False
@dataclass(frozen=True)
class PersistenceRetryPolicy:
    """Immutable settings for the bounded retry applied to short run-store writes."""

    # Total number of attempts, counting the first call.
    max_attempts: int = 5
    # Pause, in seconds, before the first retry.
    initial_delay: float = 0.05
    # Upper bound, in seconds, for the pause once it has been scaled.
    max_delay: float = 1.0
    # Multiplier applied to the pause after each failed attempt.
    backoff_factor: float = 2.0
@dataclass
class RunRecord:
@@ -58,38 +111,100 @@ class RunManager:
that run history survives process restarts.
"""
def __init__(self, store: RunStore | None = None) -> None:
def __init__(
self,
store: RunStore | None = None,
*,
persistence_retry_policy: PersistenceRetryPolicy | None = None,
) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
self._store = store
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
async def _persist_to_store(self, record: RunRecord) -> None:
"""Best-effort persist run record to backing store."""
@staticmethod
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
return {
"thread_id": record.thread_id,
"assistant_id": record.assistant_id,
"status": record.status.value,
"multitask_strategy": record.multitask_strategy,
"metadata": record.metadata or {},
"kwargs": record.kwargs or {},
"error": error if error is not None else record.error,
"created_at": record.created_at,
"model_name": record.model_name,
}
async def _call_store_with_retry(
    self,
    operation_name: str,
    run_id: str,
    operation: Callable[[], Awaitable[Any]],
) -> Any:
    """Run a short store operation with bounded retries for SQLite pressure.

    Args:
        operation_name: Label for the operation, used only in log messages.
        run_id: Run identifier, used only in log messages.
        operation: Zero-argument callable returning a fresh awaitable on each
            call; it is invoked once per attempt.

    Returns:
        Whatever ``operation`` returns on the first successful attempt.

    Raises:
        Exception: The last error from ``operation`` when it is not classified
            as retryable or when ``max_attempts`` has been reached.
    """
    policy = self._persistence_retry_policy
    attempt = 1
    # Clamp the very first pause as well, so a policy whose initial_delay
    # exceeds max_delay never sleeps longer than its own cap.
    delay = min(policy.max_delay, policy.initial_delay)
    while True:
        try:
            return await operation()
        except Exception as exc:
            if attempt >= policy.max_attempts or not _is_retryable_persistence_error(exc):
                raise
            logger.warning(
                "Transient persistence failure during %s for run %s (attempt %d/%d); retrying",
                operation_name,
                run_id,
                attempt,
                policy.max_attempts,
                exc_info=True,
            )
        # Reached only when the attempt failed with a retryable error.
        if delay > 0:
            await asyncio.sleep(delay)
        # Grow the pause geometrically, bounded by max_delay; a zero delay
        # restarts from initial_delay (still subject to the cap).
        delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay)
        attempt += 1
async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool:
"""Best-effort persist a previously captured run snapshot."""
if self._store is None:
return
return True
try:
await self._store.put(
record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
multitask_strategy=record.multitask_strategy,
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
model_name=record.model_name,
await self._call_store_with_retry(
"put",
run_id,
lambda: self._store.put(run_id, **payload),
)
return True
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
return False
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool:
    """Best-effort write of *record* to the backing store.

    Builds the put payload from the record (with the optional *error*
    override) and hands it to ``_persist_snapshot_to_store``, whose result is
    returned unchanged.
    """
    snapshot = self._store_put_payload(record, error=error)
    return await self._persist_snapshot_to_store(record.run_id, snapshot)
async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool:
"""Best-effort persist a status transition to the backing store."""
if self._store is None:
return
return True
row_recovery_payload = self._store_put_payload(record, error=error)
try:
await self._store.update_status(run_id, status.value, error=error)
updated = await self._call_store_with_retry(
"update_status",
record.run_id,
lambda: self._store.update_status(record.run_id, status.value, error=error),
)
if updated is False:
return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload)
return True
except Exception:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True)
return False
@staticmethod
def _record_from_store(row: dict[str, Any]) -> RunRecord:
@@ -126,6 +241,7 @@ class RunManager:
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
row_recovery_payload: dict[str, Any] | None = None
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
@@ -135,11 +251,30 @@ class RunManager:
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error"))
if self._store is None:
return
try:
updated = await self._call_store_with_retry(
"update_run_completion",
run_id,
lambda: self._store.update_run_completion(run_id, **kwargs),
)
if updated is False:
if row_recovery_payload is None:
logger.warning("Failed to recreate missing run %s for completion persistence", run_id)
return
if not await self._persist_snapshot_to_store(run_id, row_recovery_payload):
return
recovered = await self._call_store_with_retry(
"update_run_completion",
run_id,
lambda: self._store.update_run_completion(run_id, **kwargs),
)
if recovered is False:
logger.warning("Run completion update for %s affected no rows after row recreation", run_id)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
@@ -273,7 +408,7 @@ class RunManager:
record.updated_at = _now_iso()
if error is not None:
record.error = error
await self._persist_status(run_id, status, error=error)
await self._persist_status(record, status, error=error)
logger.info("Run %s -> %s", run_id, status.value)
async def _persist_model_name(self, run_id: str, model_name: str | None) -> None:
@@ -281,7 +416,11 @@ class RunManager:
if self._store is None:
return
try:
await self._store.update_model_name(run_id, model_name)
await self._call_store_with_retry(
"update_model_name",
run_id,
lambda: self._store.update_model_name(run_id, model_name),
)
except Exception:
logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True)
@@ -324,7 +463,7 @@ class RunManager:
record.task.cancel()
record.status = RunStatus.interrupted
record.updated_at = _now_iso()
await self._persist_status(run_id, RunStatus.interrupted)
await self._persist_status(record, RunStatus.interrupted)
logger.info("Run %s cancelled (action=%s)", run_id, action)
return True
@@ -352,7 +491,7 @@ class RunManager:
now = _now_iso()
_supported_strategies = ("reject", "interrupt", "rollback")
interrupted_run_ids: list[str] = []
interrupted_records: list[RunRecord] = []
async with self._lock:
if multitask_strategy not in _supported_strategies:
@@ -371,7 +510,7 @@ class RunManager:
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_run_ids.append(r.run_id)
interrupted_records.append(r)
logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
len(inflight),
@@ -394,12 +533,66 @@ class RunManager:
)
self._runs[run_id] = record
for interrupted_run_id in interrupted_run_ids:
await self._persist_status(interrupted_run_id, RunStatus.interrupted)
for interrupted_record in interrupted_records:
await self._persist_status(interrupted_record, RunStatus.interrupted)
await self._persist_to_store(record)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
async def reconcile_orphaned_inflight_runs(
self,
*,
error: str,
before: str | None = None,
) -> list[RunRecord]:
"""Mark persisted active runs as failed when no local task owns them.
Gateway runs are process-local: the asyncio task and abort event live in
memory, while the run row is durable. After a SQLite-backed gateway
restart, any persisted ``pending`` or ``running`` row created before
startup cannot still have a local worker. This recovery step turns that
ambiguous state into an explicit error instead of letting the UI show an
indefinite active run.
Args:
error: Message set on each reconciled record and persisted with the error status.
before: Optional cut-off passed through unchanged to the store's ``list_inflight``.
Returns:
The records, rebuilt from store rows, whose error status was persisted.
An empty list is returned when no store is configured or listing fails.
"""
if self._store is None:
return []
try:
# Listing is not tied to a single run, so "*" is passed where the retry
# helper expects the run id it puts in its log messages.
rows = await self._call_store_with_retry(
"list_inflight",
"*",
lambda: self._store.list_inflight(before=before),
)
except Exception:
logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True)
return []
recovered: list[RunRecord] = []
# One timestamp is shared by every run reconciled in this pass.
now = _now_iso()
for row in rows:
try:
record = self._record_from_store(row)
except Exception:
# A row that cannot be mapped is logged and skipped; the remaining rows are still processed.
logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True)
continue
async with self._lock:
live_record = self._runs.get(record.run_id)
# A run this process still tracks as pending/running is not orphaned; leave it alone.
if live_record is not None and live_record.status in (RunStatus.pending, RunStatus.running):
continue
# The fields are set on the record rebuilt from the stored row, not on any entry in self._runs.
record.status = RunStatus.error
record.error = error
record.updated_at = now
persisted = await self._persist_status(record, RunStatus.error, error=error)
# Only runs whose error status was reported as persisted are returned as recovered.
if not persisted:
logger.warning("Skipped orphaned run %s recovery because error status was not persisted", record.run_id)
continue
recovered.append(record)
if recovered:
logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered))
return recovered
async def has_inflight(self, thread_id: str) -> bool:
"""Return ``True`` if *thread_id* has a pending or running run."""
async with self._lock:
@@ -59,7 +59,12 @@ class RunStore(abc.ABC):
status: str,
*,
error: str | None = None,
) -> None:
) -> bool | None:
"""Update a run status.
Returns ``False`` when the store can prove no row was updated. Older or
lightweight stores may return ``None`` when they cannot report rowcount.
"""
pass
@abc.abstractmethod
@@ -92,7 +97,11 @@ class RunStore(abc.ABC):
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
) -> bool | None:
"""Persist final completion fields.
Returns ``False`` when the store can prove no row was updated.
"""
pass
async def update_run_progress(
@@ -117,6 +126,11 @@ class RunStore(abc.ABC):
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def list_inflight(self, *, before: str | None = None) -> list[dict[str, Any]]:
    """Return persisted runs that are still ``pending`` or ``running``.

    ``before`` is an optional cut-off supplied by the caller; concrete stores
    decide how to apply it.
    """
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
@@ -65,6 +65,8 @@ class MemoryRunStore(RunStore):
if error is not None:
self._runs[run_id]["error"] = error
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_model_name(self, run_id, model_name):
if run_id in self._runs:
@@ -81,6 +83,8 @@ class MemoryRunStore(RunStore):
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
@@ -95,6 +99,12 @@ class MemoryRunStore(RunStore):
results.sort(key=lambda r: r["created_at"])
return results
async def list_inflight(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] in ("pending", "running") and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]