fix: harden run finalization persistence (#3155)

* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
This commit is contained in:
AochenShen99
2026-05-23 00:09:06 +08:00
committed by GitHub
parent f0bae28636
commit 66d6a6a4e8
8 changed files with 755 additions and 56 deletions
+35
View File
@@ -37,11 +37,36 @@ if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime import RunRecord
T = TypeVar("T") T = TypeVar("T")
async def _mark_latest_recovered_threads_error(
    run_manager: RunManager,
    thread_store: ThreadMetaStore,
    recovered_runs: list[RunRecord],
) -> None:
    """Flag a thread as ``error`` only if its newest run is one that was just recovered.

    Recovered runs are grouped by thread. For each thread the single most
    recent run is fetched via ``run_manager.list_by_thread(..., limit=1)``;
    the thread status is only touched when that run is among the recovered
    ones, so an old orphaned run cannot overwrite the state left by a newer
    run. Both the lookup and the status write are best-effort: failures are
    logged and the remaining threads are still processed.
    """
    run_ids_by_thread: dict[str, set[str]] = {}
    for run in recovered_runs:
        thread_key = run.thread_id
        if thread_key not in run_ids_by_thread:
            run_ids_by_thread[thread_key] = set()
        run_ids_by_thread[thread_key].add(run.run_id)

    for thread_id, recovered_ids in run_ids_by_thread.items():
        try:
            newest = await run_manager.list_by_thread(thread_id, user_id=None, limit=1)
        except Exception:
            logger.warning("Failed to find latest run for thread %s during run reconciliation", thread_id, exc_info=True)
            continue

        # Nothing to do when the thread has no runs or its newest run was not recovered.
        newest_was_recovered = bool(newest) and newest[0].run_id in recovered_ids
        if not newest_was_recovered:
            continue

        try:
            await thread_store.update_status(thread_id, "error", user_id=None)
        except Exception:
            logger.warning("Failed to mark thread %s as error during run reconciliation", thread_id, exc_info=True)
def get_config() -> AppConfig: def get_config() -> AppConfig:
"""Return the freshest ``AppConfig`` for the current request. """Return the freshest ``AppConfig`` for the current request.
@@ -138,6 +163,16 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
# RunManager with store backing for persistence # RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store) app.state.run_manager = RunManager(store=app.state.run_store)
if getattr(config.database, "backend", None) == "sqlite":
from deerflow.utils.time import now_iso
# Startup-only recovery: clean shutdowns return no active rows and
# the thread-status update below becomes a no-op.
recovered_runs = await app.state.run_manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before=now_iso(),
)
await _mark_latest_recovered_threads_error(app.state.run_manager, app.state.thread_store, recovered_runs)
try: try:
yield yield
@@ -94,25 +94,35 @@ class RunRepository(RunStore):
created_at=None, created_at=None,
follow_up_to_run_id=None, follow_up_to_run_id=None,
): ):
"""Insert or update a run row.
``RunManager`` retries ``put`` after transient SQLite failures. Making
this operation idempotent prevents a successful-but-unacknowledged first
commit from turning the retry into a primary-key failure.
"""
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put") resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
now = datetime.now(UTC) now = datetime.now(UTC)
row = RunRow( created = datetime.fromisoformat(created_at) if created_at else now
run_id=run_id, values = {
thread_id=thread_id, "thread_id": thread_id,
assistant_id=assistant_id, "assistant_id": assistant_id,
user_id=resolved_user_id, "user_id": resolved_user_id,
model_name=self._normalize_model_name(model_name), "model_name": self._normalize_model_name(model_name),
status=status, "status": status,
multitask_strategy=multitask_strategy, "multitask_strategy": multitask_strategy,
metadata_json=self._safe_json(metadata) or {}, "metadata_json": self._safe_json(metadata) or {},
kwargs_json=self._safe_json(kwargs) or {}, "kwargs_json": self._safe_json(kwargs) or {},
error=error, "error": error,
follow_up_to_run_id=follow_up_to_run_id, "follow_up_to_run_id": follow_up_to_run_id,
created_at=datetime.fromisoformat(created_at) if created_at else now, "updated_at": now,
updated_at=now, }
)
async with self._sf() as session: async with self._sf() as session:
session.add(row) row = await session.get(RunRow, run_id)
if row is None:
session.add(RunRow(run_id=run_id, created_at=created, **values))
else:
for key, value in values.items():
setattr(row, key, value)
await session.commit() await session.commit()
async def get( async def get(
@@ -146,13 +156,14 @@ class RunRepository(RunStore):
result = await session.execute(stmt) result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()] return [self._row_to_dict(r) for r in result.scalars()]
async def update_status(self, run_id, status, *, error=None): async def update_status(self, run_id, status, *, error=None) -> bool:
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)} values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
if error is not None: if error is not None:
values["error"] = error values["error"] = error
async with self._sf() as session: async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
return result.rowcount != 0
async def update_model_name(self, run_id, model_name): async def update_model_name(self, run_id, model_name):
async with self._sf() as session: async with self._sf() as session:
@@ -187,6 +198,26 @@ class RunRepository(RunStore):
result = await session.execute(stmt) result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()] return [self._row_to_dict(r) for r in result.scalars()]
async def list_inflight(self, *, before=None):
    """Return persisted ``pending``/``running`` runs for startup recovery.

    ``before`` is the inclusive ``created_at`` cutoff. A ``datetime`` is used
    as given, ``None`` means the current UTC time, and any other value is
    parsed with ``datetime.fromisoformat``. Rows are returned oldest first,
    each converted with ``_row_to_dict``.
    """
    if isinstance(before, datetime):
        cutoff = before
    elif before is None:
        cutoff = datetime.now(UTC)
    else:
        cutoff = datetime.fromisoformat(before)

    active_statuses = ("pending", "running")
    query = select(RunRow).where(
        RunRow.status.in_(active_statuses),
        RunRow.created_at <= cutoff,
    )
    query = query.order_by(RunRow.created_at.asc())

    async with self._sf() as session:
        rows = (await session.execute(query)).scalars()
        return [self._row_to_dict(row) for row in rows]
async def update_run_completion( async def update_run_completion(
self, self,
run_id: str, run_id: str,
@@ -203,8 +234,11 @@ class RunRepository(RunStore):
last_ai_message: str | None = None, last_ai_message: str | None = None,
first_human_message: str | None = None, first_human_message: str | None = None,
error: str | None = None, error: str | None = None,
) -> None: ) -> bool:
"""Update status + token usage + convenience fields on run completion.""" """Update status + token usage + convenience fields on run completion.
Returns ``False`` when no run row matched the requested ``run_id``.
"""
values: dict[str, Any] = { values: dict[str, Any] = {
"status": status, "status": status,
"total_input_tokens": total_input_tokens, "total_input_tokens": total_input_tokens,
@@ -224,8 +258,9 @@ class RunRepository(RunStore):
if error is not None: if error is not None:
values["error"] = error values["error"] = error
async with self._sf() as session: async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
return result.rowcount != 0
async def update_run_progress( async def update_run_progress(
self, self,
@@ -4,7 +4,9 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import sqlite3
import uuid import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -17,6 +19,57 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_RETRYABLE_SQLITE_MESSAGES = (
"database is locked",
"database table is locked",
"database is busy",
)
_RETRYABLE_SQLITE_ERROR_CODES = {
sqlite3.SQLITE_BUSY,
sqlite3.SQLITE_LOCKED,
}
def _is_retryable_persistence_error(exc: BaseException) -> bool:
"""Return True for transient SQLite persistence failures.
SQLite lock contention normally surfaces through either sqlite3 exceptions
or SQLAlchemy wrappers. The short bounded retry here protects run status
finalization from transient writer pressure without hiding permanent
failures forever.
"""
pending: list[BaseException] = [exc]
seen: set[int] = set()
while pending:
current = pending.pop()
if id(current) in seen:
continue
seen.add(id(current))
message = str(current).lower()
if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES):
return True
if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)):
error_code = getattr(current, "sqlite_errorcode", None)
if error_code in _RETRYABLE_SQLITE_ERROR_CODES:
return True
for chained in (getattr(current, "orig", None), current.__cause__, current.__context__):
if isinstance(chained, BaseException):
pending.append(chained)
return False
@dataclass(frozen=True)
class PersistenceRetryPolicy:
    """Bounded retry policy for short run-store writes.

    Consumed by ``RunManager._call_store_with_retry``. Instances are frozen,
    so a policy cannot be mutated after construction.
    """

    # Total number of attempts, including the first call.
    max_attempts: int = 5
    # Delay before the first retry; passed to ``asyncio.sleep``.
    initial_delay: float = 0.05
    # Upper bound applied to the delay after each backoff step.
    max_delay: float = 1.0
    # Multiplier applied to the delay after every failed attempt.
    backoff_factor: float = 2.0
@dataclass @dataclass
class RunRecord: class RunRecord:
@@ -58,38 +111,100 @@ class RunManager:
that run history survives process restarts. that run history survives process restarts.
""" """
def __init__(self, store: RunStore | None = None) -> None: def __init__(
self,
store: RunStore | None = None,
*,
persistence_retry_policy: PersistenceRetryPolicy | None = None,
) -> None:
self._runs: dict[str, RunRecord] = {} self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._store = store self._store = store
self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy()
async def _persist_to_store(self, record: RunRecord) -> None: @staticmethod
"""Best-effort persist run record to backing store.""" def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
if self._store is None: return {
return "thread_id": record.thread_id,
"assistant_id": record.assistant_id,
"status": record.status.value,
"multitask_strategy": record.multitask_strategy,
"metadata": record.metadata or {},
"kwargs": record.kwargs or {},
"error": error if error is not None else record.error,
"created_at": record.created_at,
"model_name": record.model_name,
}
async def _call_store_with_retry(
self,
operation_name: str,
run_id: str,
operation: Callable[[], Awaitable[Any]],
) -> Any:
"""Run a short store operation with bounded retries for SQLite pressure."""
policy = self._persistence_retry_policy
attempt = 1
delay = policy.initial_delay
while True:
try: try:
await self._store.put( return await operation()
record.run_id, except Exception as exc:
thread_id=record.thread_id, retryable = _is_retryable_persistence_error(exc)
assistant_id=record.assistant_id, if attempt >= policy.max_attempts or not retryable:
status=record.status.value, raise
multitask_strategy=record.multitask_strategy, logger.warning(
metadata=record.metadata or {}, "Transient persistence failure during %s for run %s (attempt %d/%d); retrying",
kwargs=record.kwargs or {}, operation_name,
created_at=record.created_at, run_id,
model_name=record.model_name, attempt,
policy.max_attempts,
exc_info=True,
) )
except Exception: if delay > 0:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) await asyncio.sleep(delay)
delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay)
attempt += 1
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool:
"""Best-effort persist a previously captured run snapshot."""
if self._store is None:
return True
try:
await self._call_store_with_retry(
"put",
run_id,
lambda: self._store.put(run_id, **payload),
)
return True
except Exception:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
return False
async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool:
"""Best-effort persist run record to backing store."""
return await self._persist_snapshot_to_store(
record.run_id,
self._store_put_payload(record, error=error),
)
async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool:
"""Best-effort persist a status transition to the backing store.""" """Best-effort persist a status transition to the backing store."""
if self._store is None: if self._store is None:
return return True
row_recovery_payload = self._store_put_payload(record, error=error)
try: try:
await self._store.update_status(run_id, status.value, error=error) updated = await self._call_store_with_retry(
"update_status",
record.run_id,
lambda: self._store.update_status(record.run_id, status.value, error=error),
)
if updated is False:
return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload)
return True
except Exception: except Exception:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True)
return False
@staticmethod @staticmethod
def _record_from_store(row: dict[str, Any]) -> RunRecord: def _record_from_store(row: dict[str, Any]) -> RunRecord:
@@ -126,6 +241,7 @@ class RunManager:
async def update_run_completion(self, run_id: str, **kwargs) -> None: async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store.""" """Persist token usage and completion data to the backing store."""
row_recovery_payload: dict[str, Any] | None = None
async with self._lock: async with self._lock:
record = self._runs.get(run_id) record = self._runs.get(run_id)
if record is not None: if record is not None:
@@ -135,9 +251,28 @@ class RunManager:
if hasattr(record, key) and value is not None: if hasattr(record, key) and value is not None:
setattr(record, key, value) setattr(record, key, value)
record.updated_at = _now_iso() record.updated_at = _now_iso()
if self._store is not None: row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error"))
if self._store is None:
return
try: try:
await self._store.update_run_completion(run_id, **kwargs) updated = await self._call_store_with_retry(
"update_run_completion",
run_id,
lambda: self._store.update_run_completion(run_id, **kwargs),
)
if updated is False:
if row_recovery_payload is None:
logger.warning("Failed to recreate missing run %s for completion persistence", run_id)
return
if not await self._persist_snapshot_to_store(run_id, row_recovery_payload):
return
recovered = await self._call_store_with_retry(
"update_run_completion",
run_id,
lambda: self._store.update_run_completion(run_id, **kwargs),
)
if recovered is False:
logger.warning("Run completion update for %s affected no rows after row recreation", run_id)
except Exception: except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
@@ -273,7 +408,7 @@ class RunManager:
record.updated_at = _now_iso() record.updated_at = _now_iso()
if error is not None: if error is not None:
record.error = error record.error = error
await self._persist_status(run_id, status, error=error) await self._persist_status(record, status, error=error)
logger.info("Run %s -> %s", run_id, status.value) logger.info("Run %s -> %s", run_id, status.value)
async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: async def _persist_model_name(self, run_id: str, model_name: str | None) -> None:
@@ -281,7 +416,11 @@ class RunManager:
if self._store is None: if self._store is None:
return return
try: try:
await self._store.update_model_name(run_id, model_name) await self._call_store_with_retry(
"update_model_name",
run_id,
lambda: self._store.update_model_name(run_id, model_name),
)
except Exception: except Exception:
logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True)
@@ -324,7 +463,7 @@ class RunManager:
record.task.cancel() record.task.cancel()
record.status = RunStatus.interrupted record.status = RunStatus.interrupted
record.updated_at = _now_iso() record.updated_at = _now_iso()
await self._persist_status(run_id, RunStatus.interrupted) await self._persist_status(record, RunStatus.interrupted)
logger.info("Run %s cancelled (action=%s)", run_id, action) logger.info("Run %s cancelled (action=%s)", run_id, action)
return True return True
@@ -352,7 +491,7 @@ class RunManager:
now = _now_iso() now = _now_iso()
_supported_strategies = ("reject", "interrupt", "rollback") _supported_strategies = ("reject", "interrupt", "rollback")
interrupted_run_ids: list[str] = [] interrupted_records: list[RunRecord] = []
async with self._lock: async with self._lock:
if multitask_strategy not in _supported_strategies: if multitask_strategy not in _supported_strategies:
@@ -371,7 +510,7 @@ class RunManager:
r.task.cancel() r.task.cancel()
r.status = RunStatus.interrupted r.status = RunStatus.interrupted
r.updated_at = now r.updated_at = now
interrupted_run_ids.append(r.run_id) interrupted_records.append(r)
logger.info( logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)", "Cancelled %d inflight run(s) on thread %s (strategy=%s)",
len(inflight), len(inflight),
@@ -394,12 +533,66 @@ class RunManager:
) )
self._runs[run_id] = record self._runs[run_id] = record
for interrupted_run_id in interrupted_run_ids: for interrupted_record in interrupted_records:
await self._persist_status(interrupted_run_id, RunStatus.interrupted) await self._persist_status(interrupted_record, RunStatus.interrupted)
await self._persist_to_store(record) await self._persist_to_store(record)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
async def reconcile_orphaned_inflight_runs(
    self,
    *,
    error: str,
    before: str | None = None,
) -> list[RunRecord]:
    """Mark persisted active runs as failed when no local task owns them.

    Gateway runs are process-local: the asyncio task and abort event live in
    memory, while the run row is durable. After a SQLite-backed gateway
    restart, any persisted ``pending`` or ``running`` row created before
    startup cannot still have a local worker. This recovery step turns that
    ambiguous state into an explicit error instead of letting the UI show an
    indefinite active run.

    Args:
        error: Error text stored on every recovered run.
        before: Optional cutoff forwarded to the store's ``list_inflight``.

    Returns:
        The records whose ``error`` status was persisted. Rows that cannot be
        listed, mapped, or persisted are logged and left out.
    """
    if self._store is None:
        return []

    try:
        rows = await self._call_store_with_retry(
            "list_inflight",
            "*",
            lambda: self._store.list_inflight(before=before),
        )
    except Exception:
        logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True)
        return []

    recovered: list[RunRecord] = []
    now = _now_iso()
    active_states = (RunStatus.pending, RunStatus.running)
    for row in rows:
        try:
            orphan = self._record_from_store(row)
        except Exception:
            logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True)
            continue

        # A run that is still active in this process is not orphaned.
        async with self._lock:
            local = self._runs.get(orphan.run_id)
            still_owned = local is not None and local.status in active_states
        if still_owned:
            continue

        orphan.status = RunStatus.error
        orphan.error = error
        orphan.updated_at = now
        if await self._persist_status(orphan, RunStatus.error, error=error):
            recovered.append(orphan)
        else:
            logger.warning("Skipped orphaned run %s recovery because error status was not persisted", orphan.run_id)

    if recovered:
        logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered))
    return recovered
async def has_inflight(self, thread_id: str) -> bool: async def has_inflight(self, thread_id: str) -> bool:
"""Return ``True`` if *thread_id* has a pending or running run.""" """Return ``True`` if *thread_id* has a pending or running run."""
async with self._lock: async with self._lock:
@@ -59,7 +59,12 @@ class RunStore(abc.ABC):
status: str, status: str,
*, *,
error: str | None = None, error: str | None = None,
) -> None: ) -> bool | None:
"""Update a run status.
Returns ``False`` when the store can prove no row was updated. Older or
lightweight stores may return ``None`` when they cannot report rowcount.
"""
pass pass
@abc.abstractmethod @abc.abstractmethod
@@ -92,7 +97,11 @@ class RunStore(abc.ABC):
last_ai_message: str | None = None, last_ai_message: str | None = None,
first_human_message: str | None = None, first_human_message: str | None = None,
error: str | None = None, error: str | None = None,
) -> None: ) -> bool | None:
"""Persist final completion fields.
Returns ``False`` when the store can prove no row was updated.
"""
pass pass
async def update_run_progress( async def update_run_progress(
@@ -117,6 +126,11 @@ class RunStore(abc.ABC):
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass pass
@abc.abstractmethod
async def list_inflight(self, *, before: str | None = None) -> list[dict[str, Any]]:
    """Return persisted runs that are still ``pending`` or ``running``.

    ``before`` is an optional cutoff. The concrete stores in this codebase
    treat it as an inclusive ``created_at`` bound, default it to the current
    UTC time, and return the oldest run first.
    """
    pass
@abc.abstractmethod @abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread. """Aggregate token usage for completed runs in a thread.
@@ -65,6 +65,8 @@ class MemoryRunStore(RunStore):
if error is not None: if error is not None:
self._runs[run_id]["error"] = error self._runs[run_id]["error"] = error
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_model_name(self, run_id, model_name): async def update_model_name(self, run_id, model_name):
if run_id in self._runs: if run_id in self._runs:
@@ -81,6 +83,8 @@ class MemoryRunStore(RunStore):
if value is not None: if value is not None:
self._runs[run_id][key] = value self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_run_progress(self, run_id, **kwargs): async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running": if run_id in self._runs and self._runs[run_id].get("status") == "running":
@@ -95,6 +99,12 @@ class MemoryRunStore(RunStore):
results.sort(key=lambda r: r["created_at"]) results.sort(key=lambda r: r["created_at"])
return results return results
async def list_inflight(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] in ("pending", "running") and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error") statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses] completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
+127
View File
@@ -0,0 +1,127 @@
"""Gateway startup recovery for stale persisted runs."""
from __future__ import annotations
from contextlib import asynccontextmanager
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
import deerflow.runtime as runtime_module
from app.gateway import deps as gateway_deps
from deerflow.persistence import engine as engine_module
from deerflow.persistence import thread_meta as thread_meta_module
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
from deerflow.runtime.events import store as event_store_module
@asynccontextmanager
async def _fake_context(value):
    """Async context manager that yields *value* and does nothing on exit.

    Stands in for the runtime factories patched in these tests, which are
    used with ``async with``.
    """
    yield value
class _FakeRunManager:
    """RunManager double that records startup reconciliation calls.

    State shared through class attributes lets a test configure the double
    before the gateway constructs it and inspect it afterwards.
    """

    # Every instance created so far, in construction order.
    instances: list[_FakeRunManager] = []
    # Records returned by ``reconcile_orphaned_inflight_runs``.
    recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
    # Per-thread override for ``list_by_thread``; falls back to ``recovered_runs``.
    latest_by_thread: dict[str, list[SimpleNamespace]] = {}

    def __init__(self, *, store):
        self.store = store
        self.reconcile_calls: list[dict] = []
        self.list_by_thread_calls: list[dict] = []
        _FakeRunManager.instances.append(self)

    async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
        """Record the call and return the configured recovered runs."""
        call = {"error": error, "before": before}
        self.reconcile_calls.append(call)
        return self.recovered_runs

    async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
        """Record the call and return the configured newest runs for the thread."""
        call = {"thread_id": thread_id, "user_id": user_id, "limit": limit}
        self.list_by_thread_calls.append(call)
        if thread_id in self.latest_by_thread:
            return self.latest_by_thread[thread_id]
        return self.recovered_runs[:limit]
class _FakeThreadStore:
    """Thread-store double that records every ``update_status`` call."""

    def __init__(self) -> None:
        # ``(thread_id, status, user_id)`` tuples in call order.
        self.status_updates: list[tuple[str, str, str | None]] = []

    async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
        """Record the requested status change instead of persisting it."""
        self.status_updates.append((thread_id, status, user_id))
def _sqlite_runtime_config() -> SimpleNamespace:
    """Build the minimal SQLite-backed config object shared by the startup tests."""
    return SimpleNamespace(
        database=SimpleNamespace(backend="sqlite"),
        run_events=SimpleNamespace(backend="memory"),
    )


def _patch_gateway_runtime(monkeypatch, thread_store) -> None:
    """Swap the runtime factories patched by these tests for in-memory doubles.

    After this call ``gateway_deps.langgraph_runtime`` builds a
    ``_FakeRunManager`` and receives *thread_store* from ``make_thread_store``.
    """

    async def fake_init_engine_from_config(_database):
        return None

    async def fake_close_engine():
        return None

    monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
    monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
    monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
    monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
    monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
    monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
    monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
    monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
    monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)


@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
    """SQLite startup should recover stale active runs before serving requests."""
    app = FastAPI()
    thread_store = _FakeThreadStore()
    # Reset the class-level state of the double so tests stay independent.
    _FakeRunManager.instances.clear()
    _FakeRunManager.recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
    _FakeRunManager.latest_by_thread = {}
    _patch_gateway_runtime(monkeypatch, thread_store)

    async with gateway_deps.langgraph_runtime(app, _sqlite_runtime_config()):
        pass

    assert len(_FakeRunManager.instances) == 1
    assert _FakeRunManager.instances[0].reconcile_calls
    assert _FakeRunManager.instances[0].reconcile_calls[0]["error"]
    assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert thread_store.status_updates == [("thread-1", "error", None)]


@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
    """Startup recovery should not let an old orphaned run overwrite a newer terminal thread state."""
    app = FastAPI()
    thread_store = _FakeThreadStore()
    # The recovered run is older than the thread's newest run, which succeeded.
    _FakeRunManager.instances.clear()
    _FakeRunManager.recovered_runs = [SimpleNamespace(run_id="old-running", thread_id="thread-1")]
    _FakeRunManager.latest_by_thread = {"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]}
    _patch_gateway_runtime(monkeypatch, thread_store)

    async with gateway_deps.langgraph_runtime(app, _sqlite_runtime_config()):
        pass

    assert len(_FakeRunManager.instances) == 1
    assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert thread_store.status_updates == []
+240
View File
@@ -1,10 +1,15 @@
"""Tests for RunManager.""" """Tests for RunManager."""
import logging
import re import re
import sqlite3
from typing import Any
import pytest import pytest
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
from deerflow.runtime import DisconnectMode, RunManager, RunStatus from deerflow.runtime import DisconnectMode, RunManager, RunStatus
from deerflow.runtime.runs.manager import PersistenceRetryPolicy
from deerflow.runtime.runs.store.memory import MemoryRunStore from deerflow.runtime.runs.store.memory import MemoryRunStore
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -15,6 +20,92 @@ def manager() -> RunManager:
return RunManager() return RunManager()
class FlakyStatusRunStore(MemoryRunStore):
    """Memory run store that simulates transient SQLite status-write failures.

    The first ``status_failures`` calls to ``update_status`` raise a
    "database is locked" error; later calls are delegated to the parent store.
    """

    def __init__(self, *, status_failures: int) -> None:
        super().__init__()
        # Failures still to be injected before writes start succeeding.
        self.status_failures = status_failures
        # Every call is counted, failed or not.
        self.status_update_attempts = 0

    async def update_status(self, run_id, status, *, error=None):
        self.status_update_attempts += 1
        if self.status_failures <= 0:
            return await super().update_status(run_id, status, error=error)
        self.status_failures -= 1
        raise sqlite3.OperationalError("database is locked")
class MissingRowStatusRunStore(MemoryRunStore):
    """Memory run store that reports a missing row for status updates."""

    async def update_status(self, run_id, status, *, error=None):
        # Apply the update through the parent store, then always report that
        # no row matched.
        await super().update_status(run_id, status, error=error)
        return False
class PermanentStatusRunStore(MemoryRunStore):
    """Memory run store that simulates a permanent SQLAlchemy write failure.

    Every ``update_status`` call raises a SQLAlchemy ``DatabaseError`` wrapping
    a "no such table" sqlite3 error.
    """

    def __init__(self) -> None:
        super().__init__()
        # Number of ``update_status`` calls received.
        self.status_update_attempts = 0

    async def update_status(self, run_id, status, *, error=None):
        self.status_update_attempts += 1
        statement = "UPDATE runs SET status = :status WHERE run_id = :run_id"
        params = {"status": status, "run_id": run_id}
        driver_error = sqlite3.DatabaseError("no such table: runs")
        raise SQLAlchemyDatabaseError(statement, params, driver_error)
class FailingStatusRunStore(MemoryRunStore):
    """Memory run store that always fails status updates."""

    def __init__(self) -> None:
        super().__init__()
        # Number of ``update_status`` calls received.
        self.status_update_attempts = 0

    async def update_status(self, run_id, status, *, error=None):
        # Count the attempt, then fail with a lock error on every call.
        self.status_update_attempts += 1
        raise sqlite3.OperationalError("database is locked")
class MissingCompletionRunStore(MemoryRunStore):
    """In-memory run store that reports a missing row on the first completion update.

    The first ``update_run_completion`` call returns ``False`` without touching
    the parent store; later calls are forwarded to it.
    """

    def __init__(self) -> None:
        super().__init__()
        # Number of update_run_completion calls received so far.
        self.completion_update_attempts = 0

    async def update_run_completion(self, run_id, *, status, **kwargs):
        """Return ``False`` on the first call, delegate on every later call."""
        self.completion_update_attempts += 1
        is_first_attempt = self.completion_update_attempts == 1
        if not is_first_attempt:
            return await super().update_run_completion(run_id, status=status, **kwargs)
        return False
class AlwaysMissingCompletionRunStore(MemoryRunStore):
    """In-memory run store whose completion updates always report a missing row.

    ``update_run_completion`` never reaches the parent store; it only counts
    the call and returns ``False``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Number of update_run_completion calls received so far.
        self.completion_update_attempts = 0

    async def update_run_completion(self, run_id, *, status, **kwargs):
        """Count the attempt and report that no row was updated."""
        self.completion_update_attempts += 1
        row_was_updated = False
        return row_was_updated
async def _stored_statuses(store: MemoryRunStore, *run_ids: str) -> dict[str, Any]:
    """Map each requested run id to its stored ``status``.

    Runs for which ``store.get`` returns a falsy value map to ``None``.
    Lookups are awaited one at a time, in the order the ids were passed.
    """

    async def _status_of(run_id: str) -> Any:
        record = await store.get(run_id)
        return None if not record else record["status"]

    return {run_id: await _status_of(run_id) for run_id in run_ids}
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_and_get(manager: RunManager): async def test_create_and_get(manager: RunManager):
"""Created run should be retrievable with new fields.""" """Created run should be retrievable with new fields."""
@@ -80,6 +171,155 @@ async def test_cancel_persists_interrupted_status_to_store():
assert stored["status"] == "interrupted" assert stored["status"] == "interrupted"
@pytest.mark.anyio
async def test_status_persistence_retries_transient_sqlite_lock():
    """A final status must reach the store despite transient SQLite lock errors.

    The store fails the first two status writes, so reaching "success" needs
    at least four ``update_status`` attempts in total.
    """
    flaky_store = FlakyStatusRunStore(status_failures=2)
    run_manager = RunManager(store=flaky_store)
    run = await run_manager.create("thread-1")

    for next_status in (RunStatus.running, RunStatus.success):
        await run_manager.set_status(run.run_id, next_status)

    persisted = await flaky_store.get(run.run_id)
    assert persisted is not None
    assert persisted["status"] == "success"
    assert flaky_store.status_update_attempts >= 4
@pytest.mark.anyio
async def test_status_persistence_recreates_missing_store_row():
    """Setting a final status should bring back a run row that was deleted.

    The row is removed from the store after creation; the error status and
    message must still be readable afterwards.
    """
    missing_row_store = MissingRowStatusRunStore()
    run_manager = RunManager(store=missing_row_store)
    run = await run_manager.create("thread-1")
    run_id = run.run_id

    # Simulate a lost initial write before the final status arrives.
    await missing_row_store.delete(run_id)
    await run_manager.set_status(run_id, RunStatus.error, error="boom")

    persisted = await missing_row_store.get(run_id)
    assert persisted is not None
    assert persisted["status"] == "error"
    assert persisted["error"] == "boom"
@pytest.mark.anyio
async def test_status_persistence_does_not_retry_permanent_sqlalchemy_errors():
    """A permanent SQLAlchemy failure should be attempted once, not retried.

    The retry policy allows five attempts, yet the store must see exactly one.
    """
    permanent_store = PermanentStatusRunStore()
    retry_policy = PersistenceRetryPolicy(max_attempts=5, initial_delay=0)
    run_manager = RunManager(store=permanent_store, persistence_retry_policy=retry_policy)
    run = await run_manager.create("thread-1")

    await run_manager.set_status(run.run_id, RunStatus.error, error="boom")

    assert permanent_store.status_update_attempts == 1
@pytest.mark.anyio
async def test_completion_persistence_recreates_missing_store_row():
    """A completion update should recreate a deleted row and store its counters.

    The store reports a missing row on the first completion attempt, so exactly
    two attempts are expected.
    """
    completion_store = MissingCompletionRunStore()
    run_manager = RunManager(store=completion_store)
    run = await run_manager.create("thread-1")
    run_id = run.run_id
    await run_manager.set_status(run_id, RunStatus.running)
    await run_manager.set_status(run_id, RunStatus.success)
    await completion_store.delete(run_id)

    final_counters = {"total_tokens": 42, "llm_call_count": 2, "last_ai_message": "done"}
    await run_manager.update_run_completion(run_id, status="success", **final_counters)

    persisted = await completion_store.get(run_id)
    assert persisted is not None
    assert persisted["status"] == "success"
    for field, expected in final_counters.items():
        assert persisted[field] == expected
    assert completion_store.completion_update_attempts == 2
@pytest.mark.anyio
async def test_completion_persistence_warns_when_recreated_row_still_missing(caplog):
    """A second zero-row completion update must leave a warning in the log.

    The store reports a missing row on every attempt; two attempts and a
    warning from the run manager logger are expected.
    """
    always_missing_store = AlwaysMissingCompletionRunStore()
    run_manager = RunManager(store=always_missing_store)
    run = await run_manager.create("thread-1")
    await run_manager.set_status(run.run_id, RunStatus.success)

    # Capture starts only now, so earlier log output is not part of the check.
    caplog.set_level(logging.WARNING, logger="deerflow.runtime.runs.manager")
    await run_manager.update_run_completion(run.run_id, status="success", total_tokens=42)

    assert always_missing_store.completion_update_attempts == 2
    assert "affected no rows after row recreation" in caplog.text
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_marks_stale_rows_error():
    """Reconciliation should flip stored pending/running rows to "error".

    A row already in "success" must be left alone and not reported as recovered.
    """
    store = MemoryRunStore()
    seeded_rows = (
        ("pending-run", "pending", "2026-01-01T00:00:00+00:00"),
        ("running-run", "running", "2026-01-01T00:00:01+00:00"),
        ("success-run", "success", "2026-01-01T00:00:02+00:00"),
    )
    for run_id, status, created_at in seeded_rows:
        await store.put(run_id, thread_id="thread-1", status=status, created_at=created_at)
    run_manager = RunManager(store=store)

    recovered = await run_manager.reconcile_orphaned_inflight_runs(
        error="Gateway restarted before this run reached a durable final state.",
        before="2026-01-01T00:00:02+00:00",
    )

    recovered_ids = {record.run_id for record in recovered}
    assert recovered_ids == {"pending-run", "running-run"}
    expected_statuses = {
        "pending-run": "error",
        "running-run": "error",
        "success-run": "success",
    }
    assert await _stored_statuses(store, "pending-run", "running-run", "success-run") == expected_statuses
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_live_local_run():
    """Startup recovery should not mark an active row orphaned when this worker owns it.

    The run is created and set to running through the same manager that then
    reconciles, so nothing may be recovered and the stored status stays "running".
    """
    store = MemoryRunStore()
    manager = RunManager(store=store)
    record = await manager.create("thread-1")
    await manager.set_status(record.run_id, RunStatus.running)

    recovered = await manager.reconcile_orphaned_inflight_runs(
        error="Gateway restarted before this run reached a durable final state.",
    )

    stored = await store.get(record.run_id)
    assert recovered == []
    # Guard first so a vanished row fails as an assertion, not as a TypeError
    # from subscripting None.
    assert stored is not None
    assert stored["status"] == "running"
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_rows_when_error_status_is_not_persisted():
    """Startup recovery must not report a row as recovered if the error update failed.

    The store rejects every status write, so with a two-attempt retry policy
    the row stays "running", nothing is recovered, and two attempts are counted.
    """
    store = FailingStatusRunStore()
    await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:00+00:00")
    manager = RunManager(
        store=store,
        persistence_retry_policy=PersistenceRetryPolicy(max_attempts=2, initial_delay=0),
    )

    recovered = await manager.reconcile_orphaned_inflight_runs(
        error="Gateway restarted before this run reached a durable final state.",
        before="2026-01-01T00:00:01+00:00",
    )

    stored = await store.get("running-run")
    assert recovered == []
    # Guard first so a vanished row fails as an assertion, not as a TypeError
    # from subscripting None.
    assert stored is not None
    assert stored["status"] == "running"
    assert store.status_update_attempts == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_cancel_not_inflight(manager: RunManager): async def test_cancel_not_inflight(manager: RunManager):
"""Cancelling a completed run should return False.""" """Cancelling a completed run should return False."""
+47 -2
View File
@@ -52,6 +52,9 @@ class _CustomRunStoreWithoutProgress(RunStore):
async def list_pending(self, *args, **kwargs):
    """Return no pending runs, whatever arguments are supplied."""
    return []
async def list_inflight(self, *args, **kwargs):
    """Return no in-flight runs, whatever arguments are supplied."""
    no_runs = []
    return no_runs
async def aggregate_tokens_by_thread(self, *args, **kwargs):
    """Return an empty aggregate, whatever arguments are supplied."""
    return {}
@@ -75,6 +78,19 @@ class TestRunRepository:
assert row["status"] == "pending" assert row["status"] == "pending"
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_put_is_idempotent_for_retried_writes(self, tmp_path):
    """A second ``put`` for the same run id should succeed and overwrite the row."""
    repo = await _make_repo(tmp_path)
    first_write = {"assistant_id": "old-agent", "status": "pending"}
    retried_write = {"assistant_id": "new-agent", "status": "running", "error": "retry"}
    await repo.put("r1", thread_id="t1", **first_write)
    await repo.put("r1", thread_id="t1", **retried_write)

    stored = await repo.get("r1")
    assert stored["assistant_id"] == "new-agent"
    assert stored["status"] == "running"
    assert stored["error"] == "retry"
    await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_missing_returns_none(self, tmp_path): async def test_get_missing_returns_none(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -85,11 +101,19 @@ class TestRunRepository:
async def test_update_status(self, tmp_path):
    """``update_status`` on an existing row should return True and store the new status."""
    repo = await _make_repo(tmp_path)
    await repo.put("r1", thread_id="t1")
    updated = await repo.update_status("r1", "running")
    row = await repo.get("r1")
    assert updated is True
    assert row["status"] == "running"
    await _cleanup()
@pytest.mark.anyio
async def test_update_status_returns_false_for_missing_row(self, tmp_path):
    """``update_status`` for a run id that was never stored should return False."""
    repo = await _make_repo(tmp_path)
    row_was_updated = await repo.update_status("missing", "error", error="lost")
    assert row_was_updated is False
    await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_status_with_error(self, tmp_path): async def test_update_status_with_error(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -146,11 +170,24 @@ class TestRunRepository:
assert all(r["status"] == "pending" for r in pending) assert all(r["status"] == "pending" for r in pending)
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_list_inflight_returns_pending_and_running_before_cutoff(self, tmp_path):
    """``list_inflight`` should return only pending/running rows created before the cutoff.

    The finished row and the pending row created after the cutoff must be excluded.
    """
    repo = await _make_repo(tmp_path)
    seeded_rows = (
        ("pending-old", "pending", "2026-01-01T00:00:00+00:00"),
        ("running-old", "running", "2026-01-01T00:00:01+00:00"),
        ("success-old", "success", "2026-01-01T00:00:02+00:00"),
        ("pending-new", "pending", "2026-01-01T00:00:03+00:00"),
    )
    for run_id, status, created_at in seeded_rows:
        await repo.put(run_id, thread_id="t1", status=status, created_at=created_at)

    inflight = await repo.list_inflight(before="2026-01-01T00:00:02+00:00")

    returned_ids = [row["run_id"] for row in inflight]
    assert returned_ids == ["pending-old", "running-old"]
    await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_update_run_completion(self, tmp_path): async def test_update_run_completion(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running") await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion( updated = await repo.update_run_completion(
"r1", "r1",
status="success", status="success",
total_input_tokens=100, total_input_tokens=100,
@@ -165,6 +202,7 @@ class TestRunRepository:
first_human_message="What is the meaning?", first_human_message="What is the meaning?",
) )
row = await repo.get("r1") row = await repo.get("r1")
assert updated is True
assert row["status"] == "success" assert row["status"] == "success"
assert row["total_tokens"] == 150 assert row["total_tokens"] == 150
assert row["llm_call_count"] == 2 assert row["llm_call_count"] == 2
@@ -174,6 +212,13 @@ class TestRunRepository:
assert row["first_human_message"] == "What is the meaning?" assert row["first_human_message"] == "What is the meaning?"
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion_returns_false_for_missing_row(self, tmp_path):
    """``update_run_completion`` for a run id that was never stored should return False."""
    repo = await _make_repo(tmp_path)
    row_was_updated = await repo.update_run_completion("missing", status="error", total_tokens=1)
    assert row_was_updated is False
    await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_metadata_preserved(self, tmp_path): async def test_metadata_preserved(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)