mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(gateway): drain in-flight runs before closing checkpointer on shutdown (#3381)
* fix(gateway): drain in-flight runs before closing checkpointer on shutdown Chat runs execute in fire-and-forget background asyncio tasks that write checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's AsyncExitStack tore down the checkpointer's postgres connection pool while those run tasks were still mid-graph. langgraph's AsyncPregelLoop._checkpointer_put_after_previous then ran its `finally: await checkpointer.aput(...)` against the closed pool, raising psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task (not on run_agent's call stack), run_agent's try/except cannot catch it and it surfaces as "unhandled exception during asyncio.run() shutdown". Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and call it from langgraph_runtime BEFORE the AsyncExitStack closes the checkpointer, so the final checkpoint write lands while the pool is still open. The drain is bounded by a timeout so a stuck run cannot hang worker shutdown, and is shielded so a second shutdown signal cannot abandon it mid-drain and reopen the race. Closes #3373 * fix(gateway): address review — preserve completed-run status, bound drain persistence Addresses Copilot review on #3381: - RunManager.shutdown(): decide run status AFTER the drain. Under the lock it now only requests cancellation; after asyncio.wait it marks/persists `interrupted` only for runs still pending or ended cancelled. A run that completes (e.g. `success`) during the drain window keeps its real terminal status instead of being unconditionally overwritten. - Bound the trailing status persistence within the timeout budget (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow store backing off under DB pressure cannot push shutdown past the deadline. - deps: use asyncio.create_task instead of asyncio.ensure_future. - tests: wait deterministically for the run to be in-flight (poll the first checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the recovery test double; add regression test asserting a run completing during the drain keeps its status (in memory and in the store). * fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant Addresses @WillemJiang review on #3381: - shutdown(): inspect the gather result of the trailing interrupted-status persistence. _persist_status is best-effort (it catches + logs its own failure with exc_info and returns False, so it never raises out of the gather), but the aggregate result was never checked — a partial failure had no shutdown-level visibility. Now any escaped Exception is logged, and any False (a persist that did not confirm) is logged with the run_id. Added regression test test_shutdown_surfaces_failed_interrupted_persist. - deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the lifespan shutdown window. Kept as two separate constants (independent teardown steps that may diverge) rather than one shared "must match" value. - Verified no other test fake needs the shutdown stub: _FakeRunManager in test_worker_langfuse_metadata.py is a run_agent() argument (worker path), never injected into langgraph_runtime, so it never receives shutdown().
This commit is contained in:
@@ -0,0 +1,353 @@
|
||||
"""Regression tests for graceful run-task drain on Gateway shutdown.
|
||||
|
||||
Guards bytedance/deer-flow issue #3373:
|
||||
|
||||
psycopg_pool.PoolClosed: the pool 'pool-1' is already closed
|
||||
|
||||
Root cause: chat runs are fire-and-forget background ``asyncio`` tasks
|
||||
(``app/gateway/services.py`` -> ``asyncio.create_task(run_agent(...))``) owned
|
||||
by nobody. On shutdown, ``langgraph_runtime``'s ``AsyncExitStack`` tore down the
|
||||
checkpointer's postgres pool while those tasks were still mid-graph. langgraph's
|
||||
``AsyncPregelLoop._checkpointer_put_after_previous`` then ran its
|
||||
``finally: await checkpointer.aput(...)`` against the already-closed pool.
|
||||
|
||||
Fix: ``RunManager.shutdown()`` cancels and *bounded*-awaits every in-flight run,
|
||||
and ``langgraph_runtime`` calls it BEFORE the ``AsyncExitStack`` closes the
|
||||
checkpointer — so the final checkpoint write lands while the pool is still open.
|
||||
The drain must stay bounded (a stuck run must not hang the worker, the
|
||||
precondition for the signal-reentrancy deadlock guarded by
|
||||
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import operator
|
||||
from contextlib import asynccontextmanager, suppress
|
||||
from types import SimpleNamespace
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime import RunManager, RunStatus
|
||||
|
||||
|
||||
# Module-level so langgraph's get_type_hints (which resolves annotations against
|
||||
# module globals under `from __future__ import annotations`) can see Annotated.
|
||||
class _CountState(TypedDict):
|
||||
count: Annotated[int, operator.add]
|
||||
|
||||
|
||||
class _CloseableSaver(InMemorySaver):
|
||||
"""InMemorySaver that fails writes once closed, like a closed pool."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._closed = False
|
||||
self.writes_after_close: list[str] = []
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
async def aput(self, *args, **kwargs):
|
||||
if self._closed:
|
||||
self.writes_after_close.append("aput")
|
||||
raise RuntimeError("checkpointer is closed")
|
||||
return await super().aput(*args, **kwargs)
|
||||
|
||||
async def aput_writes(self, *args, **kwargs):
|
||||
if self._closed:
|
||||
self.writes_after_close.append("aput_writes")
|
||||
raise RuntimeError("checkpointer is closed")
|
||||
return await super().aput_writes(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_cancels_and_awaits_inflight_run():
|
||||
"""shutdown() cancels the in-flight task, waits for it, marks it interrupted."""
|
||||
rm = RunManager()
|
||||
record = await rm.create("t-drain")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def worker() -> None:
|
||||
try:
|
||||
started.set()
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
record.task = asyncio.create_task(worker())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
await rm.shutdown(timeout=5.0)
|
||||
|
||||
assert record.task.done()
|
||||
assert cancelled.is_set()
|
||||
assert record.status == RunStatus.interrupted
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_is_bounded_when_run_ignores_cancellation():
|
||||
"""A run that swallows cancellation must not make shutdown() hang."""
|
||||
rm = RunManager()
|
||||
record = await rm.create("t-stubborn")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
started = asyncio.Event()
|
||||
stop = asyncio.Event()
|
||||
|
||||
async def stubborn() -> None:
|
||||
started.set()
|
||||
while not stop.is_set():
|
||||
try:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
if stop.is_set():
|
||||
raise
|
||||
# else: swallow — simulates a run stuck in slow cleanup
|
||||
|
||||
record.task = asyncio.create_task(stubborn())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
t0 = loop.time()
|
||||
await rm.shutdown(timeout=0.3)
|
||||
elapsed = loop.time() - t0
|
||||
|
||||
assert elapsed < 2.0, f"shutdown took {elapsed:.2f}s; drain is not bounded"
|
||||
finally:
|
||||
# cleanup the deliberately-stubborn task
|
||||
stop.set()
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_is_noop_without_inflight_runs():
|
||||
"""shutdown() on an idle manager completes cleanly and is idempotent."""
|
||||
rm = RunManager()
|
||||
await rm.shutdown(timeout=1.0)
|
||||
# already-finished runs must not be re-cancelled or error out
|
||||
record = await rm.create("t-done")
|
||||
await rm.set_status(record.run_id, RunStatus.success)
|
||||
await rm.shutdown(timeout=1.0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
|
||||
"""The wiring order lock for #3373: drain in-flight runs, THEN close the pool.
|
||||
|
||||
Patches every ``langgraph_runtime`` collaborator down to trivial stand-ins so
|
||||
only the bootstrap/teardown ordering runs. The checkpointer probe records when
|
||||
its context manager exits (pool close); a ``RunManager.shutdown`` spy records
|
||||
when the drain happens. The drain MUST come first.
|
||||
"""
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.gateway.deps import langgraph_runtime
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
@asynccontextmanager
|
||||
async def probe_checkpointer(_config):
|
||||
try:
|
||||
yield object()
|
||||
finally:
|
||||
events.append("checkpointer_closed")
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_stream_bridge(_config):
|
||||
yield object()
|
||||
|
||||
@asynccontextmanager
|
||||
async def fake_store(_config):
|
||||
yield object()
|
||||
|
||||
async def fake_init_engine(_db):
|
||||
return None
|
||||
|
||||
async def fake_close_engine():
|
||||
return None
|
||||
|
||||
async def spy_shutdown(self, *, timeout): # noqa: ANN001
|
||||
events.append("runs_drained")
|
||||
|
||||
monkeypatch.setattr("deerflow.runtime.checkpointer.async_provider.make_checkpointer", probe_checkpointer)
|
||||
monkeypatch.setattr("deerflow.runtime.make_stream_bridge", fake_stream_bridge)
|
||||
monkeypatch.setattr("deerflow.runtime.make_store", fake_store)
|
||||
monkeypatch.setattr("deerflow.persistence.engine.init_engine_from_config", fake_init_engine)
|
||||
monkeypatch.setattr("deerflow.persistence.engine.close_engine", fake_close_engine)
|
||||
monkeypatch.setattr("deerflow.persistence.engine.get_session_factory", lambda: None)
|
||||
monkeypatch.setattr("deerflow.runtime.events.store.make_run_event_store", lambda _cfg: object())
|
||||
monkeypatch.setattr("deerflow.persistence.thread_meta.make_thread_store", lambda _sf, _store: object())
|
||||
monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)
|
||||
|
||||
app = FastAPI()
|
||||
startup_config = SimpleNamespace(database=SimpleNamespace(backend="memory"), run_events=None)
|
||||
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
pass
|
||||
|
||||
assert "runs_drained" in events, "langgraph_runtime never drained in-flight runs on shutdown"
|
||||
assert "checkpointer_closed" in events
|
||||
assert events.index("runs_drained") < events.index("checkpointer_closed"), f"runs must be drained before the checkpointer pool is closed; got order {events}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_flushes_real_graph_checkpoint_before_close():
|
||||
"""End-to-end #3373 guard with a REAL langgraph graph + checkpointer.
|
||||
|
||||
A real run is driven through ``graph.astream`` in a background task, then
|
||||
``RunManager.shutdown()`` drains it. The checkpointer raises once closed
|
||||
(mirroring ``psycopg_pool.PoolClosed``). Closing only happens AFTER the
|
||||
drain — as the gateway's AsyncExitStack does. The drain must let langgraph
|
||||
flush its final checkpoint while the checkpointer is still open, so no write
|
||||
lands against a closed checkpointer.
|
||||
|
||||
Unlike the unit/spy tests above, this exercises the real langgraph
|
||||
checkpoint-put machinery, so a future langgraph change that cancels (rather
|
||||
than awaits) its checkpoint-put task on executor exit would fail this test
|
||||
instead of silently regressing #3373.
|
||||
"""
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
async def slow(_state: _CountState) -> dict:
|
||||
await asyncio.sleep(0.1)
|
||||
return {"count": 1}
|
||||
|
||||
saver = _CloseableSaver()
|
||||
builder = StateGraph(_CountState)
|
||||
for name in ("a", "b", "c"):
|
||||
builder.add_node(name, slow)
|
||||
builder.add_edge(START, "a")
|
||||
builder.add_edge("a", "b")
|
||||
builder.add_edge("b", "c")
|
||||
builder.add_edge("c", END)
|
||||
graph = builder.compile(checkpointer=saver)
|
||||
|
||||
rm = RunManager()
|
||||
record = await rm.create("t-e2e")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
thread_cfg = {"configurable": {"thread_id": "t-e2e"}}
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def run() -> None:
|
||||
started.set()
|
||||
async for _ in graph.astream({"count": 0}, config=thread_cfg):
|
||||
pass
|
||||
|
||||
record.task = asyncio.create_task(run())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
# Deterministically wait until the run is genuinely in-flight — poll for
|
||||
# the first persisted checkpoint instead of a fixed sleep (avoids CI
|
||||
# flakiness on slow runners / under event-loop contention).
|
||||
async def _await_first_checkpoint() -> None:
|
||||
while (await saver.aget_tuple(thread_cfg)) is None:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
await asyncio.wait_for(_await_first_checkpoint(), timeout=5.0)
|
||||
|
||||
# The fix: drain while the checkpointer is still open ...
|
||||
await rm.shutdown(timeout=5.0)
|
||||
# ... and only then close it (mirrors langgraph_runtime's ExitStack).
|
||||
saver.close()
|
||||
|
||||
assert saver.writes_after_close == [], f"a checkpoint write raced a closed checkpointer: {saver.writes_after_close}"
|
||||
# The final checkpoint landed before close.
|
||||
snapshot = await saver.aget_tuple(thread_cfg)
|
||||
assert snapshot is not None
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_preserves_status_of_run_completed_during_drain():
|
||||
"""A run that finishes (e.g. success) during the drain window must keep its
|
||||
real terminal status — shutdown must not blanket-overwrite it to
|
||||
``interrupted`` in memory or in the store (Copilot review on PR #3381)."""
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
store = MemoryRunStore()
|
||||
rm = RunManager(store=store)
|
||||
record = await rm.create("t-complete")
|
||||
await rm.set_status(record.run_id, RunStatus.running)
|
||||
|
||||
async def worker() -> None:
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
# The run had effectively finished; swallow the cancellation and
|
||||
# record success, like a run that completed in the same tick the
|
||||
# shutdown cancelled it.
|
||||
pass
|
||||
await rm.set_status(record.run_id, RunStatus.success)
|
||||
|
||||
record.task = asyncio.create_task(worker())
|
||||
try:
|
||||
await asyncio.sleep(0) # let the task reach its await point
|
||||
|
||||
await rm.shutdown(timeout=5.0)
|
||||
|
||||
assert record.status == RunStatus.success, f"shutdown overwrote in-memory status: {record.status}"
|
||||
persisted = await store.get(record.run_id)
|
||||
assert persisted is not None and persisted["status"] == "success", f"shutdown overwrote persisted status: {persisted}"
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shutdown_surfaces_failed_interrupted_persist(caplog):
|
||||
"""A failed interrupted-status persist during the drain must be surfaced (with
|
||||
the run_id), not silently swallowed by the gather (maintainer review on
|
||||
PR #3381)."""
|
||||
import logging
|
||||
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
class _FailingStore(MemoryRunStore):
|
||||
async def update_status(self, *args, **kwargs):
|
||||
raise RuntimeError("store unavailable")
|
||||
|
||||
rm = RunManager(store=_FailingStore())
|
||||
record = await rm.create("t-failpersist")
|
||||
record.status = RunStatus.running # set in memory; the failing store is exercised by the drain
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def worker() -> None:
|
||||
started.set()
|
||||
await asyncio.Event().wait() # blocks until cancelled by the drain
|
||||
|
||||
record.task = asyncio.create_task(worker())
|
||||
try:
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.runtime.runs.manager"):
|
||||
await rm.shutdown(timeout=5.0)
|
||||
assert "Could not persist interrupted status for run" in caplog.text, caplog.text
|
||||
finally:
|
||||
if not record.task.done():
|
||||
record.task.cancel()
|
||||
with suppress(asyncio.CancelledError):
|
||||
await record.task
|
||||
Reference in New Issue
Block a user