mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
268fdd6968
* fix(gateway): drain in-flight runs before closing checkpointer on shutdown

  Chat runs execute in fire-and-forget background asyncio tasks that write checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's AsyncExitStack tore down the checkpointer's postgres connection pool while those run tasks were still mid-graph. langgraph's AsyncPregelLoop._checkpointer_put_after_previous then ran its `finally: await checkpointer.aput(...)` against the closed pool, raising psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task (not on run_agent's call stack), run_agent's try/except cannot catch it and it surfaces as "unhandled exception during asyncio.run() shutdown".

  Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and call it from langgraph_runtime BEFORE the AsyncExitStack closes the checkpointer, so the final checkpoint write lands while the pool is still open. The drain is bounded by a timeout so a stuck run cannot hang worker shutdown, and is shielded so a second shutdown signal cannot abandon it mid-drain and reopen the race.

  Closes #3373

* fix(gateway): address review — preserve completed-run status, bound drain persistence

  Addresses Copilot review on #3381:

  - RunManager.shutdown(): decide run status AFTER the drain. Under the lock it now only requests cancellation; after asyncio.wait it marks/persists `interrupted` only for runs still pending or ended cancelled. A run that completes (e.g. `success`) during the drain window keeps its real terminal status instead of being unconditionally overwritten.
  - Bound the trailing status persistence within the timeout budget (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow store backing off under DB pressure cannot push shutdown past the deadline.
  - deps: use asyncio.create_task instead of asyncio.ensure_future.
  - tests: wait deterministically for the run to be in-flight (poll the first checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the recovery test double; add regression test asserting a run completing during the drain keeps its status (in memory and in the store).

* fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant

  Addresses @WillemJiang review on #3381:

  - shutdown(): inspect the gather result of the trailing interrupted-status persistence. _persist_status is best-effort (it catches + logs its own failure with exc_info and returns False, so it never raises out of the gather), but the aggregate result was never checked — a partial failure had no shutdown-level visibility. Now any escaped Exception is logged, and any False (a persist that did not confirm) is logged with the run_id. Added regression test test_shutdown_surfaces_failed_interrupted_persist.
  - deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the lifespan shutdown window. Kept as two separate constants (independent teardown steps that may diverge) rather than one shared "must match" value.
  - Verified no other test fake needs the shutdown stub: _FakeRunManager in test_worker_langfuse_metadata.py is a run_agent() argument (worker path), never injected into langgraph_runtime, so it never receives shutdown().
134 lines
5.9 KiB
Python
134 lines
5.9 KiB
Python
"""Gateway startup recovery for stale persisted runs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import asynccontextmanager
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
|
|
import deerflow.runtime as runtime_module
|
|
from app.gateway import deps as gateway_deps
|
|
from deerflow.persistence import engine as engine_module
|
|
from deerflow.persistence import thread_meta as thread_meta_module
|
|
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
|
|
from deerflow.runtime.events import store as event_store_module
|
|
|
|
|
|
@asynccontextmanager
async def _fake_context(value):
    """Async context manager stub that yields ``value`` unchanged.

    Entering the context hands back exactly the object that was passed in;
    there is no setup or teardown work on either side of the ``yield``.
    """
    yield value
|
|
|
|
|
|
class _FakeRunManager:
|
|
"""RunManager double that records startup reconciliation calls."""
|
|
|
|
instances: list[_FakeRunManager] = []
|
|
recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
|
|
latest_by_thread: dict[str, list[SimpleNamespace]] = {}
|
|
|
|
def __init__(self, *, store):
|
|
self.store = store
|
|
self.reconcile_calls: list[dict] = []
|
|
self.list_by_thread_calls: list[dict] = []
|
|
self.shutdown_calls: int = 0
|
|
_FakeRunManager.instances.append(self)
|
|
|
|
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
|
|
self.reconcile_calls.append({"error": error, "before": before})
|
|
return self.recovered_runs
|
|
|
|
async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
|
|
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
|
|
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
|
|
|
|
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
|
# No in-flight tasks in these startup-recovery tests; langgraph_runtime
|
|
# drains the manager on teardown, so the double must accept the call.
|
|
self.shutdown_calls += 1
|
|
|
|
|
|
class _FakeThreadStore:
    """Thread-store double that records every status update it is asked to make."""

    def __init__(self) -> None:
        # One (thread_id, status, user_id) tuple per call, in call order.
        self.status_updates: list[tuple[str, str, str | None]] = []

    async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
        """Append the requested status change to ``status_updates``."""
        record = (thread_id, status, user_id)
        self.status_updates.append(record)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
    """SQLite startup should recover stale active runs before serving requests.

    Every runtime dependency is replaced with an in-memory stand-in, the
    runtime context is entered and left once, and the recorded calls are
    checked: one RunManager was built, it was asked to reconcile with a
    non-empty error, the thread's latest run was looked up with ``limit=1``,
    and the thread was marked ``error``.
    """
    # Reset the class-level state shared by every _FakeRunManager instance.
    _FakeRunManager.instances.clear()
    _FakeRunManager.recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
    _FakeRunManager.latest_by_thread = {}

    recording_store = _FakeThreadStore()
    settings = SimpleNamespace(
        database=SimpleNamespace(backend="sqlite"),
        run_events=SimpleNamespace(backend="memory"),
    )

    async def fake_init_engine_from_config(_database):
        """Stand-in for engine initialisation; does nothing."""

    async def fake_close_engine():
        """Stand-in for engine teardown; does nothing."""

    # (module/object, attribute, replacement) -- applied in this order.
    replacements = (
        (engine_module, "init_engine_from_config", fake_init_engine_from_config),
        (engine_module, "get_session_factory", lambda: None),
        (engine_module, "close_engine", fake_close_engine),
        (runtime_module, "make_stream_bridge", lambda _config: _fake_context(object())),
        (checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object())),
        (runtime_module, "make_store", lambda _config: _fake_context(object())),
        (thread_meta_module, "make_thread_store", lambda _sf, _store: recording_store),
        (event_store_module, "make_run_event_store", lambda _config: object()),
        (gateway_deps, "RunManager", _FakeRunManager),
    )
    for target, attribute, stand_in in replacements:
        monkeypatch.setattr(target, attribute, stand_in)

    async with gateway_deps.langgraph_runtime(FastAPI(), settings):
        pass

    assert len(_FakeRunManager.instances) == 1
    manager = _FakeRunManager.instances[0]
    assert manager.reconcile_calls
    assert manager.reconcile_calls[0]["error"]
    assert manager.list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert recording_store.status_updates == [("thread-1", "error", None)]
|
|
|
|
|
|
@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
    """Startup recovery should not let an old orphaned run overwrite a newer terminal thread state.

    The RunManager double reports ``old-running`` as recovered, while the
    thread's latest run is ``newer-success`` with status ``success``. After
    entering and leaving the runtime, the thread store must have received no
    status updates.
    """
    # Reset the class-level state shared by every _FakeRunManager instance.
    _FakeRunManager.instances.clear()
    _FakeRunManager.recovered_runs = [SimpleNamespace(run_id="old-running", thread_id="thread-1")]
    _FakeRunManager.latest_by_thread = {
        "thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")],
    }

    recording_store = _FakeThreadStore()
    settings = SimpleNamespace(
        database=SimpleNamespace(backend="sqlite"),
        run_events=SimpleNamespace(backend="memory"),
    )

    async def fake_init_engine_from_config(_database):
        """Stand-in for engine initialisation; does nothing."""

    async def fake_close_engine():
        """Stand-in for engine teardown; does nothing."""

    # (module/object, attribute, replacement) -- applied in this order.
    replacements = (
        (engine_module, "init_engine_from_config", fake_init_engine_from_config),
        (engine_module, "get_session_factory", lambda: None),
        (engine_module, "close_engine", fake_close_engine),
        (runtime_module, "make_stream_bridge", lambda _config: _fake_context(object())),
        (checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object())),
        (runtime_module, "make_store", lambda _config: _fake_context(object())),
        (thread_meta_module, "make_thread_store", lambda _sf, _store: recording_store),
        (event_store_module, "make_run_event_store", lambda _config: object()),
        (gateway_deps, "RunManager", _FakeRunManager),
    )
    for target, attribute, stand_in in replacements:
        monkeypatch.setattr(target, attribute, stand_in)

    async with gateway_deps.langgraph_runtime(FastAPI(), settings):
        pass

    assert len(_FakeRunManager.instances) == 1
    manager = _FakeRunManager.instances[0]
    assert manager.list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert recording_store.status_updates == []
|