fix: harden run finalization persistence (#3155)

* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
This commit is contained in:
AochenShen99
2026-05-23 00:09:06 +08:00
committed by GitHub
parent f0bae28636
commit 66d6a6a4e8
8 changed files with 755 additions and 56 deletions
+127
View File
@@ -0,0 +1,127 @@
"""Gateway startup recovery for stale persisted runs."""
from __future__ import annotations
from contextlib import asynccontextmanager
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
import deerflow.runtime as runtime_module
from app.gateway import deps as gateway_deps
from deerflow.persistence import engine as engine_module
from deerflow.persistence import thread_meta as thread_meta_module
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
from deerflow.runtime.events import store as event_store_module
@asynccontextmanager
async def _fake_context(value):
yield value
class _FakeRunManager:
"""RunManager double that records startup reconciliation calls."""
instances: list[_FakeRunManager] = []
recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
latest_by_thread: dict[str, list[SimpleNamespace]] = {}
def __init__(self, *, store):
self.store = store
self.reconcile_calls: list[dict] = []
self.list_by_thread_calls: list[dict] = []
_FakeRunManager.instances.append(self)
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
self.reconcile_calls.append({"error": error, "before": before})
return self.recovered_runs
async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
class _FakeThreadStore:
def __init__(self) -> None:
self.status_updates: list[tuple[str, str, str | None]] = []
async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
self.status_updates.append((thread_id, status, user_id))
@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
"""SQLite startup should recover stale active runs before serving requests."""
app = FastAPI()
config = SimpleNamespace(
database=SimpleNamespace(backend="sqlite"),
run_events=SimpleNamespace(backend="memory"),
)
thread_store = _FakeThreadStore()
_FakeRunManager.instances.clear()
_FakeRunManager.recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
_FakeRunManager.latest_by_thread = {}
async def fake_init_engine_from_config(_database):
return None
async def fake_close_engine():
return None
monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)
async with gateway_deps.langgraph_runtime(app, config):
pass
assert len(_FakeRunManager.instances) == 1
assert _FakeRunManager.instances[0].reconcile_calls
assert _FakeRunManager.instances[0].reconcile_calls[0]["error"]
assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
assert thread_store.status_updates == [("thread-1", "error", None)]
@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
"""Startup recovery should not let an old orphaned run overwrite a newer terminal thread state."""
app = FastAPI()
config = SimpleNamespace(
database=SimpleNamespace(backend="sqlite"),
run_events=SimpleNamespace(backend="memory"),
)
thread_store = _FakeThreadStore()
_FakeRunManager.instances.clear()
_FakeRunManager.recovered_runs = [SimpleNamespace(run_id="old-running", thread_id="thread-1")]
_FakeRunManager.latest_by_thread = {"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]}
async def fake_init_engine_from_config(_database):
return None
async def fake_close_engine():
return None
monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)
async with gateway_deps.langgraph_runtime(app, config):
pass
assert len(_FakeRunManager.instances) == 1
assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
assert thread_store.status_updates == []
+240
View File
@@ -1,10 +1,15 @@
"""Tests for RunManager."""
import logging
import re
import sqlite3
from typing import Any
import pytest
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
from deerflow.runtime import DisconnectMode, RunManager, RunStatus
from deerflow.runtime.runs.manager import PersistenceRetryPolicy
from deerflow.runtime.runs.store.memory import MemoryRunStore
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@@ -15,6 +20,92 @@ def manager() -> RunManager:
return RunManager()
class FlakyStatusRunStore(MemoryRunStore):
"""Memory run store that simulates transient SQLite status-write failures."""
def __init__(self, *, status_failures: int) -> None:
super().__init__()
self.status_failures = status_failures
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
if self.status_failures > 0:
self.status_failures -= 1
raise sqlite3.OperationalError("database is locked")
return await super().update_status(run_id, status, error=error)
class MissingRowStatusRunStore(MemoryRunStore):
"""Memory run store that reports a missing row for status updates."""
async def update_status(self, run_id, status, *, error=None):
await super().update_status(run_id, status, error=error)
return False
class PermanentStatusRunStore(MemoryRunStore):
"""Memory run store that simulates a permanent SQLAlchemy write failure."""
def __init__(self) -> None:
super().__init__()
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
raise SQLAlchemyDatabaseError(
"UPDATE runs SET status = :status WHERE run_id = :run_id",
{"status": status, "run_id": run_id},
sqlite3.DatabaseError("no such table: runs"),
)
class FailingStatusRunStore(MemoryRunStore):
"""Memory run store that always fails status updates."""
def __init__(self) -> None:
super().__init__()
self.status_update_attempts = 0
async def update_status(self, run_id, status, *, error=None):
self.status_update_attempts += 1
raise sqlite3.OperationalError("database is locked")
class MissingCompletionRunStore(MemoryRunStore):
"""Memory run store that reports one missing row for completion updates."""
def __init__(self) -> None:
super().__init__()
self.completion_update_attempts = 0
async def update_run_completion(self, run_id, *, status, **kwargs):
self.completion_update_attempts += 1
if self.completion_update_attempts == 1:
return False
return await super().update_run_completion(run_id, status=status, **kwargs)
class AlwaysMissingCompletionRunStore(MemoryRunStore):
"""Memory run store that keeps reporting missing rows for completion updates."""
def __init__(self) -> None:
super().__init__()
self.completion_update_attempts = 0
async def update_run_completion(self, run_id, *, status, **kwargs):
self.completion_update_attempts += 1
return False
async def _stored_statuses(store: MemoryRunStore, *run_ids: str) -> dict[str, Any]:
rows = {}
for run_id in run_ids:
row = await store.get(run_id)
rows[run_id] = row["status"] if row else None
return rows
@pytest.mark.anyio
async def test_create_and_get(manager: RunManager):
"""Created run should be retrievable with new fields."""
@@ -80,6 +171,155 @@ async def test_cancel_persists_interrupted_status_to_store():
assert stored["status"] == "interrupted"
@pytest.mark.anyio
async def test_status_persistence_retries_transient_sqlite_lock():
"""Transient SQLite lock errors should not leave a final status stale."""
store = FlakyStatusRunStore(status_failures=2)
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
await manager.set_status(record.run_id, RunStatus.success)
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "success"
assert store.status_update_attempts >= 4
@pytest.mark.anyio
async def test_status_persistence_recreates_missing_store_row():
"""A final status update should recreate a run row if initial persistence was lost."""
store = MissingRowStatusRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await store.delete(record.run_id)
await manager.set_status(record.run_id, RunStatus.error, error="boom")
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "error"
assert stored["error"] == "boom"
@pytest.mark.anyio
async def test_status_persistence_does_not_retry_permanent_sqlalchemy_errors():
"""Permanent SQLAlchemy failures should not be retried as SQLite pressure."""
store = PermanentStatusRunStore()
manager = RunManager(
store=store,
persistence_retry_policy=PersistenceRetryPolicy(max_attempts=5, initial_delay=0),
)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.error, error="boom")
assert store.status_update_attempts == 1
@pytest.mark.anyio
async def test_completion_persistence_recreates_missing_store_row():
"""Completion updates should recreate a missing row and persist final counters."""
store = MissingCompletionRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
await manager.set_status(record.run_id, RunStatus.success)
await store.delete(record.run_id)
await manager.update_run_completion(
record.run_id,
status="success",
total_tokens=42,
llm_call_count=2,
last_ai_message="done",
)
stored = await store.get(record.run_id)
assert stored is not None
assert stored["status"] == "success"
assert stored["total_tokens"] == 42
assert stored["llm_call_count"] == 2
assert stored["last_ai_message"] == "done"
assert store.completion_update_attempts == 2
@pytest.mark.anyio
async def test_completion_persistence_warns_when_recreated_row_still_missing(caplog):
"""A second zero-row completion update after recreation should not be silent."""
store = AlwaysMissingCompletionRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.success)
caplog.set_level(logging.WARNING, logger="deerflow.runtime.runs.manager")
await manager.update_run_completion(record.run_id, status="success", total_tokens=42)
assert store.completion_update_attempts == 2
assert "affected no rows after row recreation" in caplog.text
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_marks_stale_rows_error():
"""Startup recovery should turn persisted active rows into explicit errors."""
store = MemoryRunStore()
await store.put("pending-run", thread_id="thread-1", status="pending", created_at="2026-01-01T00:00:00+00:00")
await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:01+00:00")
await store.put("success-run", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:02+00:00")
manager = RunManager(store=store)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before="2026-01-01T00:00:02+00:00",
)
assert {record.run_id for record in recovered} == {"pending-run", "running-run"}
assert await _stored_statuses(store, "pending-run", "running-run", "success-run") == {
"pending-run": "error",
"running-run": "error",
"success-run": "success",
}
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_live_local_run():
"""Startup recovery should not mark an active row orphaned when this worker owns it."""
store = MemoryRunStore()
manager = RunManager(store=store)
record = await manager.create("thread-1")
await manager.set_status(record.run_id, RunStatus.running)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
)
stored = await store.get(record.run_id)
assert recovered == []
assert stored["status"] == "running"
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_rows_when_error_status_is_not_persisted():
"""Startup recovery must not report a row as recovered if the error update failed."""
store = FailingStatusRunStore()
await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:00+00:00")
manager = RunManager(
store=store,
persistence_retry_policy=PersistenceRetryPolicy(max_attempts=2, initial_delay=0),
)
recovered = await manager.reconcile_orphaned_inflight_runs(
error="Gateway restarted before this run reached a durable final state.",
before="2026-01-01T00:00:01+00:00",
)
stored = await store.get("running-run")
assert recovered == []
assert stored["status"] == "running"
assert store.status_update_attempts == 2
@pytest.mark.anyio
async def test_cancel_not_inflight(manager: RunManager):
"""Cancelling a completed run should return False."""
+47 -2
View File
@@ -52,6 +52,9 @@ class _CustomRunStoreWithoutProgress(RunStore):
async def list_pending(self, *args, **kwargs):
return []
async def list_inflight(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@@ -75,6 +78,19 @@ class TestRunRepository:
assert row["status"] == "pending"
await _cleanup()
@pytest.mark.anyio
async def test_put_is_idempotent_for_retried_writes(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", assistant_id="old-agent", status="pending")
await repo.put("r1", thread_id="t1", assistant_id="new-agent", status="running", error="retry")
row = await repo.get("r1")
assert row["assistant_id"] == "new-agent"
assert row["status"] == "running"
assert row["error"] == "retry"
await _cleanup()
@pytest.mark.anyio
async def test_get_missing_returns_none(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -85,11 +101,19 @@ class TestRunRepository:
async def test_update_status(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
await repo.update_status("r1", "running")
updated = await repo.update_status("r1", "running")
row = await repo.get("r1")
assert updated is True
assert row["status"] == "running"
await _cleanup()
@pytest.mark.anyio
async def test_update_status_returns_false_for_missing_row(self, tmp_path):
repo = await _make_repo(tmp_path)
updated = await repo.update_status("missing", "error", error="lost")
assert updated is False
await _cleanup()
@pytest.mark.anyio
async def test_update_status_with_error(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -146,11 +170,24 @@ class TestRunRepository:
assert all(r["status"] == "pending" for r in pending)
await _cleanup()
@pytest.mark.anyio
async def test_list_inflight_returns_pending_and_running_before_cutoff(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("pending-old", thread_id="t1", status="pending", created_at="2026-01-01T00:00:00+00:00")
await repo.put("running-old", thread_id="t1", status="running", created_at="2026-01-01T00:00:01+00:00")
await repo.put("success-old", thread_id="t1", status="success", created_at="2026-01-01T00:00:02+00:00")
await repo.put("pending-new", thread_id="t1", status="pending", created_at="2026-01-01T00:00:03+00:00")
inflight = await repo.list_inflight(before="2026-01-01T00:00:02+00:00")
assert [row["run_id"] for row in inflight] == ["pending-old", "running-old"]
await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion(
updated = await repo.update_run_completion(
"r1",
status="success",
total_input_tokens=100,
@@ -165,6 +202,7 @@ class TestRunRepository:
first_human_message="What is the meaning?",
)
row = await repo.get("r1")
assert updated is True
assert row["status"] == "success"
assert row["total_tokens"] == 150
assert row["llm_call_count"] == 2
@@ -174,6 +212,13 @@ class TestRunRepository:
assert row["first_human_message"] == "What is the meaning?"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion_returns_false_for_missing_row(self, tmp_path):
repo = await _make_repo(tmp_path)
updated = await repo.update_run_completion("missing", status="error", total_tokens=1)
assert updated is False
await _cleanup()
@pytest.mark.anyio
async def test_metadata_preserved(self, tmp_path):
repo = await _make_repo(tmp_path)