Files
deer-flow/backend/tests/test_gateway_run_recovery.py
T
Xinmin Zeng 268fdd6968 fix(gateway): drain in-flight runs before closing checkpointer on shutdown (#3381)
* fix(gateway): drain in-flight runs before closing checkpointer on shutdown

Chat runs execute in fire-and-forget background asyncio tasks that write
checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's
AsyncExitStack tore down the checkpointer's postgres connection pool while
those run tasks were still mid-graph. langgraph's
AsyncPregelLoop._checkpointer_put_after_previous then ran its
`finally: await checkpointer.aput(...)` against the closed pool, raising
psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task
(not on run_agent's call stack), run_agent's try/except cannot catch it and it
surfaces as "unhandled exception during asyncio.run() shutdown".

Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and
call it from langgraph_runtime BEFORE the AsyncExitStack closes the
checkpointer, so the final checkpoint write lands while the pool is still open.
The drain is bounded by a timeout so a stuck run cannot hang worker shutdown,
and is shielded so a second shutdown signal cannot abandon it mid-drain and
reopen the race.

Closes #3373

* fix(gateway): address review — preserve completed-run status, bound drain persistence

Addresses Copilot review on #3381:

- RunManager.shutdown(): decide run status AFTER the drain. Under the lock it
  now only requests cancellation; after asyncio.wait it marks/persists
  `interrupted` only for runs still pending or ended cancelled. A run that
  completes (e.g. `success`) during the drain window keeps its real terminal
  status instead of being unconditionally overwritten.
- Bound the trailing status persistence within the timeout budget
  (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow
  store backing off under DB pressure cannot push shutdown past the deadline.
- deps: use asyncio.create_task instead of asyncio.ensure_future.
- tests: wait deterministically for the run to be in-flight (poll the first
  checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the
  recovery test double; add regression test asserting a run completing during
  the drain keeps its status (in memory and in the store).

* fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant

Addresses @WillemJiang review on #3381:

- shutdown(): inspect the gather result of the trailing interrupted-status
  persistence. _persist_status is best-effort (it catches + logs its own
  failure with exc_info and returns False, so it never raises out of the
  gather), but the aggregate result was never checked — a partial failure had
  no shutdown-level visibility. Now any escaped Exception is logged, and any
  False (a persist that did not confirm) is logged with the run_id. Added
  regression test test_shutdown_surfaces_failed_interrupted_persist.
- deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value
  of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the
  lifespan shutdown window. Kept as two separate constants (independent teardown
  steps that may diverge) rather than one shared "must match" value.
- Verified no other test fake needs the shutdown stub: _FakeRunManager in
  test_worker_langfuse_metadata.py is a run_agent() argument (worker path),
  never injected into langgraph_runtime, so it never receives shutdown().
2026-06-07 11:24:30 +08:00

134 lines
5.9 KiB
Python

"""Gateway startup recovery for stale persisted runs."""
from __future__ import annotations
from contextlib import asynccontextmanager
from types import SimpleNamespace
import pytest
from fastapi import FastAPI
import deerflow.runtime as runtime_module
from app.gateway import deps as gateway_deps
from deerflow.persistence import engine as engine_module
from deerflow.persistence import thread_meta as thread_meta_module
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
from deerflow.runtime.events import store as event_store_module
@asynccontextmanager
async def _fake_context(value):
    """Enter as *value* itself, with no setup and no teardown.

    The tests patch the stream-bridge, checkpointer and store factories to
    return this, so the runtime can enter them without real resources.
    """
    entered = value
    yield entered
class _FakeRunManager:
    """RunManager double that records startup reconciliation and shutdown calls.

    Configuration lives in class attributes so a test can set it before
    ``langgraph_runtime`` constructs the manager. Those attributes are shared
    by every test in the module, so each test should reset them (preferably via
    ``monkeypatch.setattr`` so they are restored afterwards).
    """

    # Every instance constructed so far, in construction order.
    instances: list[_FakeRunManager] = []
    # Runs returned by reconcile_orphaned_inflight_runs(); also the fallback
    # source for list_by_thread() when a thread has no explicit override.
    recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
    # Optional per-thread override of what list_by_thread() returns.
    latest_by_thread: dict[str, list[SimpleNamespace]] = {}

    def __init__(self, *, store):
        self.store = store
        self.reconcile_calls: list[dict] = []
        self.list_by_thread_calls: list[dict] = []
        self.shutdown_calls: int = 0
        _FakeRunManager.instances.append(self)

    async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
        """Record the call and return the configured ``recovered_runs``."""
        self.reconcile_calls.append({"error": error, "before": before})
        return self.recovered_runs

    async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
        """Record the call and return at most *limit* runs for *thread_id*.

        Uses ``latest_by_thread[thread_id]`` when configured; otherwise falls
        back to the recovered runs that belong to *thread_id*.
        """
        self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
        if thread_id in self.latest_by_thread:
            runs = self.latest_by_thread[thread_id]
        else:
            # Only runs of the requested thread; previously runs of any thread leaked in.
            runs = [run for run in self.recovered_runs if run.thread_id == thread_id]
        # Slice on both paths so the double never returns more than was asked for.
        return runs[:limit]

    async def shutdown(self, *, timeout: float = 5.0) -> None:
        """Count the drain request; these tests start no runs, so nothing is awaited."""
        # langgraph_runtime drains the manager on teardown, so the double must
        # accept the call.
        self.shutdown_calls += 1
class _FakeThreadStore:
    """Thread-metadata store double that captures status updates instead of persisting them."""

    def __init__(self) -> None:
        # One (thread_id, status, user_id) tuple per update_status() call, in call order.
        self.status_updates: list[tuple[str, str, str | None]] = []

    async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
        """Append the requested update to ``status_updates``."""
        entry = (thread_id, status, user_id)
        self.status_updates.append(entry)
def _sqlite_config() -> SimpleNamespace:
    """Return the config stub used by these tests: SQLite database, in-memory run events."""
    return SimpleNamespace(
        database=SimpleNamespace(backend="sqlite"),
        run_events=SimpleNamespace(backend="memory"),
    )


def _install_runtime_fakes(
    monkeypatch,
    thread_store: _FakeThreadStore,
    *,
    recovered_runs: list[SimpleNamespace],
    latest_by_thread: dict[str, list[SimpleNamespace]],
) -> None:
    """Swap ``langgraph_runtime``'s collaborators for in-process fakes.

    Patches the engine, stream bridge, checkpointer, store, thread store,
    event store and RunManager on their defining modules. The
    ``_FakeRunManager`` class-level configuration is also set through
    ``monkeypatch``, so pytest restores it after the test instead of leaking
    it into later tests.
    """

    async def fake_init_engine_from_config(_database):
        return None

    async def fake_close_engine():
        return None

    monkeypatch.setattr(_FakeRunManager, "instances", [])
    monkeypatch.setattr(_FakeRunManager, "recovered_runs", recovered_runs)
    monkeypatch.setattr(_FakeRunManager, "latest_by_thread", latest_by_thread)

    monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
    monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
    monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
    monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
    monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
    monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
    monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
    monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
    monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)


@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
    """SQLite startup should recover stale active runs before serving requests."""
    app = FastAPI()
    thread_store = _FakeThreadStore()
    _install_runtime_fakes(
        monkeypatch,
        thread_store,
        recovered_runs=[SimpleNamespace(run_id="run-1", thread_id="thread-1")],
        latest_by_thread={},
    )

    async with gateway_deps.langgraph_runtime(app, _sqlite_config()):
        pass

    assert len(_FakeRunManager.instances) == 1
    manager = _FakeRunManager.instances[0]
    assert manager.reconcile_calls
    assert manager.reconcile_calls[0]["error"]
    assert manager.list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert thread_store.status_updates == [("thread-1", "error", None)]
    # langgraph_runtime drains the run manager on teardown.
    assert manager.shutdown_calls == 1


@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
    """Startup recovery should not let an old orphaned run overwrite a newer terminal thread state."""
    app = FastAPI()
    thread_store = _FakeThreadStore()
    _install_runtime_fakes(
        monkeypatch,
        thread_store,
        recovered_runs=[SimpleNamespace(run_id="old-running", thread_id="thread-1")],
        latest_by_thread={"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]},
    )

    async with gateway_deps.langgraph_runtime(app, _sqlite_config()):
        pass

    assert len(_FakeRunManager.instances) == 1
    manager = _FakeRunManager.instances[0]
    assert manager.list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert thread_store.status_updates == []
    # langgraph_runtime drains the run manager on teardown.
    assert manager.shutdown_calls == 1