mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(gateway): drain in-flight runs before closing checkpointer on shutdown (#3381)
* fix(gateway): drain in-flight runs before closing checkpointer on shutdown Chat runs execute in fire-and-forget background asyncio tasks that write checkpoints through a shared checkpointer. On shutdown, langgraph_runtime's AsyncExitStack tore down the checkpointer's postgres connection pool while those run tasks were still mid-graph. langgraph's AsyncPregelLoop._checkpointer_put_after_previous then ran its `finally: await checkpointer.aput(...)` against the closed pool, raising psycopg_pool.PoolClosed. Because that put runs in a langgraph-internal task (not on run_agent's call stack), run_agent's try/except cannot catch it and it surfaces as "unhandled exception during asyncio.run() shutdown". Add RunManager.shutdown() to cancel and bounded-await all in-flight runs, and call it from langgraph_runtime BEFORE the AsyncExitStack closes the checkpointer, so the final checkpoint write lands while the pool is still open. The drain is bounded by a timeout so a stuck run cannot hang worker shutdown, and is shielded so a second shutdown signal cannot abandon it mid-drain and reopen the race. Closes #3373 * fix(gateway): address review — preserve completed-run status, bound drain persistence Addresses Copilot review on #3381: - RunManager.shutdown(): decide run status AFTER the drain. Under the lock it now only requests cancellation; after asyncio.wait it marks/persists `interrupted` only for runs still pending or ended cancelled. A run that completes (e.g. `success`) during the drain window keeps its real terminal status instead of being unconditionally overwritten. - Bound the trailing status persistence within the timeout budget (deadline = loop.time()+timeout; gather wrapped in asyncio.wait_for) so a slow store backing off under DB pressure cannot push shutdown past the deadline. - deps: use asyncio.create_task instead of asyncio.ensure_future. - tests: wait deterministically for the run to be in-flight (poll the first checkpoint) instead of a fixed sleep; init shutdown_calls explicitly in the recovery test double; add regression test asserting a run completing during the drain keeps its status (in memory and in the store). * fix(gateway): address maintainer review — surface failed drain persists, clarify timeout constant Addresses @WillemJiang review on #3381: - shutdown(): inspect the gather result of the trailing interrupted-status persistence. _persist_status is best-effort (it catches + logs its own failure with exc_info and returns False, so it never raises out of the gather), but the aggregate result was never checked — a partial failure had no shutdown-level visibility. Now any escaped Exception is logged, and any False (a persist that did not confirm) is logged with the run_id. Added regression test test_shutdown_surfaces_failed_interrupted_persist. - deps: clarify the _RUN_DRAIN_TIMEOUT_SECONDS comment — state the actual value of _SHUTDOWN_HOOK_TIMEOUT_SECONDS (5.0s) and that both count toward the lifespan shutdown window. Kept as two separate constants (independent teardown steps that may diverge) rather than one shared "must match" value. - Verified no other test fake needs the shutdown stub: _FakeRunManager in test_worker_langfuse_metadata.py is a run_agent() argument (worker path), never injected into langgraph_runtime, so it never receives shutdown().
This commit is contained in:
@@ -17,6 +17,7 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import AsyncGenerator, Callable
|
from collections.abc import AsyncGenerator, Callable
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
@@ -33,6 +34,43 @@ from deerflow.runtime.runs.store.base import RunStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Upper bound (seconds) for draining in-flight runs during shutdown, before the
|
||||||
|
# AsyncExitStack tears down the checkpointer (and its connection pool). Kept
|
||||||
|
# local to avoid an app -> deps -> app import cycle. This is a *separate* budget
|
||||||
|
# from ``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS`` (currently also 5.0s,
|
||||||
|
# which bounds channel-service stop): the two govern independent teardown steps
|
||||||
|
# and may diverge, but both count toward the lifespan shutdown window — revisit
|
||||||
|
# them together if their sum must stay within the server's graceful-shutdown
|
||||||
|
# timeout.
|
||||||
|
_RUN_DRAIN_TIMEOUT_SECONDS = 5.0
|
||||||
|
|
||||||
|
|
||||||
|
async def _drain_inflight_runs(run_manager: RunManager) -> None:
|
||||||
|
"""Drain in-flight runs before the checkpointer is torn down (issue #3373).
|
||||||
|
|
||||||
|
Shields the (internally-bounded) drain so that even if the lifespan
|
||||||
|
coroutine is itself cancelled mid-shutdown — a second SIGINT or the server's
|
||||||
|
graceful-shutdown timeout, i.e. the same signal storm behind #3373 — the
|
||||||
|
checkpointer pool is not closed while run tasks are still writing
|
||||||
|
checkpoints. On such a cancellation we let the already-running drain finish
|
||||||
|
(it is bounded by ``RunManager.shutdown``'s own timeout) and then propagate
|
||||||
|
the cancellation.
|
||||||
|
"""
|
||||||
|
drain = asyncio.create_task(run_manager.shutdown(timeout=_RUN_DRAIN_TIMEOUT_SECONDS))
|
||||||
|
try:
|
||||||
|
await asyncio.shield(drain)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Re-shield so this second wait does not abandon the in-flight drain;
|
||||||
|
# it is bounded, so this cannot hang. Then re-raise to honour shutdown.
|
||||||
|
try:
|
||||||
|
await asyncio.shield(drain)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("In-flight run drain failed after shutdown cancellation")
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to drain in-flight runs during shutdown")
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
@@ -177,6 +215,14 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen
|
|||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
# Drain in-flight run tasks BEFORE the AsyncExitStack tears down the
|
||||||
|
# checkpointer (and its connection pool). A run still mid-graph would
|
||||||
|
# otherwise leak into asyncio.run() shutdown, where langgraph's
|
||||||
|
# _checkpointer_put_after_previous aput races the closed pool and
|
||||||
|
# raises PoolClosed (issue #3373).
|
||||||
|
run_manager = getattr(app.state, "run_manager", None)
|
||||||
|
if run_manager is not None:
|
||||||
|
await _drain_inflight_runs(run_manager)
|
||||||
await close_engine()
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -645,6 +645,98 @@ class RunManager:
|
|||||||
self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
logger.debug("Run record %s cleaned up", run_id)
|
logger.debug("Run record %s cleaned up", run_id)
|
||||||
|
|
||||||
|
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||||
|
"""Cancel and bounded-await all in-flight runs on process shutdown.
|
||||||
|
|
||||||
|
Chat runs execute in fire-and-forget background ``asyncio`` tasks that
|
||||||
|
write checkpoints through a shared checkpointer. On shutdown the
|
||||||
|
checkpointer's resources (e.g. the postgres connection pool owned by the
|
||||||
|
gateway's ``AsyncExitStack``) are torn down; if a run task is still
|
||||||
|
mid-graph at that point, langgraph's
|
||||||
|
``AsyncPregelLoop._checkpointer_put_after_previous`` runs its
|
||||||
|
``finally: await checkpointer.aput(...)`` against the closed pool. Because
|
||||||
|
that put runs in a langgraph-internal task (not on ``run_agent``'s call
|
||||||
|
stack), the resulting ``psycopg_pool.PoolClosed`` is not catchable by the
|
||||||
|
worker and surfaces as an unhandled exception during ``asyncio.run()``
|
||||||
|
shutdown (bytedance/deer-flow issue #3373).
|
||||||
|
|
||||||
|
Draining in-flight runs *before* the checkpointer is closed lets each
|
||||||
|
run that settles within ``timeout`` flush its final checkpoint while
|
||||||
|
resources are still open. Only runs that do **not** settle on their own
|
||||||
|
are marked ``interrupted`` — a run that completes (e.g. ``success``)
|
||||||
|
during the drain keeps its real terminal status instead of being
|
||||||
|
blanket-overwritten. The whole drain, including the trailing status
|
||||||
|
persistence, is bounded by ``timeout`` so a run stuck in cleanup (or a
|
||||||
|
slow store under DB pressure) cannot hang worker shutdown — the
|
||||||
|
precondition for the signal-reentrancy deadlock guarded by
|
||||||
|
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``. Runs still active
|
||||||
|
after ``timeout`` are logged and may still race teardown.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
deadline = loop.time() + timeout
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
inflight = [record for record in self._runs.values() if record.status in (RunStatus.pending, RunStatus.running) and record.task is not None and not record.task.done()]
|
||||||
|
for record in inflight:
|
||||||
|
record.abort_action = "interrupt"
|
||||||
|
record.abort_event.set()
|
||||||
|
record.task.cancel() # type: ignore[union-attr] # filtered above
|
||||||
|
# Status is decided AFTER the drain (below), not here: a run that
|
||||||
|
# completes on its own during the drain must keep its real status.
|
||||||
|
|
||||||
|
if not inflight:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = [record.task for record in inflight]
|
||||||
|
_, pending = await asyncio.wait(tasks, timeout=timeout)
|
||||||
|
|
||||||
|
# Only mark/persist ``interrupted`` for runs that did not settle on their
|
||||||
|
# own (still pending after the timeout, or ended cancelled). A run that
|
||||||
|
# finished normally during the drain keeps the status it set for itself.
|
||||||
|
to_persist: list[RunRecord] = []
|
||||||
|
async with self._lock:
|
||||||
|
for record in inflight:
|
||||||
|
task = record.task
|
||||||
|
if task not in pending and not task.cancelled():
|
||||||
|
# Completed on its own — retrieve any surfaced exception so it
|
||||||
|
# is not reported as "never retrieved", and keep its status.
|
||||||
|
task.exception() # type: ignore[union-attr] # done & not cancelled
|
||||||
|
continue
|
||||||
|
if record.status in (RunStatus.pending, RunStatus.running):
|
||||||
|
record.status = RunStatus.interrupted
|
||||||
|
record.updated_at = _now_iso()
|
||||||
|
to_persist.append(record)
|
||||||
|
|
||||||
|
# Bound the trailing status persistence within the remaining budget so a
|
||||||
|
# slow store (``_call_store_with_retry`` can back off under DB pressure)
|
||||||
|
# cannot push shutdown past ``timeout``.
|
||||||
|
if to_persist:
|
||||||
|
remaining = deadline - loop.time()
|
||||||
|
if remaining <= 0:
|
||||||
|
logger.warning("Run drain budget exhausted before persisting %d interrupted run(s) on shutdown", len(to_persist))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
results = await asyncio.wait_for(
|
||||||
|
asyncio.gather(*(self._persist_status(record, RunStatus.interrupted) for record in to_persist), return_exceptions=True),
|
||||||
|
timeout=remaining,
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
logger.warning("Run drain status persistence exceeded the %.1fs budget; %d record(s) may not be persisted", timeout, len(to_persist))
|
||||||
|
else:
|
||||||
|
# ``_persist_status`` is best-effort: it catches and logs its
|
||||||
|
# own failures, returning ``False``. Inspect the aggregate so a
|
||||||
|
# partial failure is surfaced at shutdown level (with the
|
||||||
|
# run_id) instead of being silently swallowed by the gather.
|
||||||
|
for record, result in zip(to_persist, results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning("Unexpected error persisting interrupted status for run %s during shutdown: %r", record.run_id, result)
|
||||||
|
elif result is False:
|
||||||
|
logger.warning("Could not persist interrupted status for run %s during shutdown", record.run_id)
|
||||||
|
|
||||||
|
if pending:
|
||||||
|
logger.warning("Run drain exceeded %.1fs on shutdown; %d run task(s) still active and may race checkpointer teardown", timeout, len(pending))
|
||||||
|
logger.info("Drained %d in-flight run(s) on shutdown (%d settled within %.1fs)", len(inflight), len(inflight) - len(pending), timeout)
|
||||||
|
|
||||||
|
|
||||||
class ConflictError(Exception):
|
class ConflictError(Exception):
|
||||||
"""Raised when multitask_strategy=reject and thread has inflight runs."""
|
"""Raised when multitask_strategy=reject and thread has inflight runs."""
|
||||||
|
|||||||
@@ -0,0 +1,353 @@
|
|||||||
|
"""Regression tests for graceful run-task drain on Gateway shutdown.
|
||||||
|
|
||||||
|
Guards bytedance/deer-flow issue #3373:
|
||||||
|
|
||||||
|
psycopg_pool.PoolClosed: the pool 'pool-1' is already closed
|
||||||
|
|
||||||
|
Root cause: chat runs are fire-and-forget background ``asyncio`` tasks
|
||||||
|
(``app/gateway/services.py`` -> ``asyncio.create_task(run_agent(...))``) owned
|
||||||
|
by nobody. On shutdown, ``langgraph_runtime``'s ``AsyncExitStack`` tore down the
|
||||||
|
checkpointer's postgres pool while those tasks were still mid-graph. langgraph's
|
||||||
|
``AsyncPregelLoop._checkpointer_put_after_previous`` then ran its
|
||||||
|
``finally: await checkpointer.aput(...)`` against the already-closed pool.
|
||||||
|
|
||||||
|
Fix: ``RunManager.shutdown()`` cancels and *bounded*-awaits every in-flight run,
|
||||||
|
and ``langgraph_runtime`` calls it BEFORE the ``AsyncExitStack`` closes the
|
||||||
|
checkpointer — so the final checkpoint write lands while the pool is still open.
|
||||||
|
The drain must stay bounded (a stuck run must not hang the worker, the
|
||||||
|
precondition for the signal-reentrancy deadlock guarded by
|
||||||
|
``app.gateway.app._SHUTDOWN_HOOK_TIMEOUT_SECONDS``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import operator
|
||||||
|
from contextlib import asynccontextmanager, suppress
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Annotated, TypedDict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
|
||||||
|
from deerflow.runtime import RunManager, RunStatus
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level so langgraph's get_type_hints (which resolves annotations against
|
||||||
|
# module globals under `from __future__ import annotations`) can see Annotated.
|
||||||
|
class _CountState(TypedDict):
|
||||||
|
count: Annotated[int, operator.add]
|
||||||
|
|
||||||
|
|
||||||
|
class _CloseableSaver(InMemorySaver):
|
||||||
|
"""InMemorySaver that fails writes once closed, like a closed pool."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._closed = False
|
||||||
|
self.writes_after_close: list[str] = []
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._closed = True
|
||||||
|
|
||||||
|
async def aput(self, *args, **kwargs):
|
||||||
|
if self._closed:
|
||||||
|
self.writes_after_close.append("aput")
|
||||||
|
raise RuntimeError("checkpointer is closed")
|
||||||
|
return await super().aput(*args, **kwargs)
|
||||||
|
|
||||||
|
async def aput_writes(self, *args, **kwargs):
|
||||||
|
if self._closed:
|
||||||
|
self.writes_after_close.append("aput_writes")
|
||||||
|
raise RuntimeError("checkpointer is closed")
|
||||||
|
return await super().aput_writes(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_cancels_and_awaits_inflight_run():
|
||||||
|
"""shutdown() cancels the in-flight task, waits for it, marks it interrupted."""
|
||||||
|
rm = RunManager()
|
||||||
|
record = await rm.create("t-drain")
|
||||||
|
await rm.set_status(record.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
|
async def worker() -> None:
|
||||||
|
try:
|
||||||
|
started.set()
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
cancelled.set()
|
||||||
|
raise
|
||||||
|
|
||||||
|
record.task = asyncio.create_task(worker())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
|
await rm.shutdown(timeout=5.0)
|
||||||
|
|
||||||
|
assert record.task.done()
|
||||||
|
assert cancelled.is_set()
|
||||||
|
assert record.status == RunStatus.interrupted
|
||||||
|
finally:
|
||||||
|
if not record.task.done():
|
||||||
|
record.task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await record.task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_is_bounded_when_run_ignores_cancellation():
|
||||||
|
"""A run that swallows cancellation must not make shutdown() hang."""
|
||||||
|
rm = RunManager()
|
||||||
|
record = await rm.create("t-stubborn")
|
||||||
|
await rm.set_status(record.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
stop = asyncio.Event()
|
||||||
|
|
||||||
|
async def stubborn() -> None:
|
||||||
|
started.set()
|
||||||
|
while not stop.is_set():
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
if stop.is_set():
|
||||||
|
raise
|
||||||
|
# else: swallow — simulates a run stuck in slow cleanup
|
||||||
|
|
||||||
|
record.task = asyncio.create_task(stubborn())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
t0 = loop.time()
|
||||||
|
await rm.shutdown(timeout=0.3)
|
||||||
|
elapsed = loop.time() - t0
|
||||||
|
|
||||||
|
assert elapsed < 2.0, f"shutdown took {elapsed:.2f}s; drain is not bounded"
|
||||||
|
finally:
|
||||||
|
# cleanup the deliberately-stubborn task
|
||||||
|
stop.set()
|
||||||
|
record.task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await record.task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_is_noop_without_inflight_runs():
|
||||||
|
"""shutdown() on an idle manager completes cleanly and is idempotent."""
|
||||||
|
rm = RunManager()
|
||||||
|
await rm.shutdown(timeout=1.0)
|
||||||
|
# already-finished runs must not be re-cancelled or error out
|
||||||
|
record = await rm.create("t-done")
|
||||||
|
await rm.set_status(record.run_id, RunStatus.success)
|
||||||
|
await rm.shutdown(timeout=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_langgraph_runtime_drains_runs_before_closing_checkpointer(monkeypatch):
|
||||||
|
"""The wiring order lock for #3373: drain in-flight runs, THEN close the pool.
|
||||||
|
|
||||||
|
Patches every ``langgraph_runtime`` collaborator down to trivial stand-ins so
|
||||||
|
only the bootstrap/teardown ordering runs. The checkpointer probe records when
|
||||||
|
its context manager exits (pool close); a ``RunManager.shutdown`` spy records
|
||||||
|
when the drain happens. The drain MUST come first.
|
||||||
|
"""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.gateway.deps import langgraph_runtime
|
||||||
|
|
||||||
|
events: list[str] = []
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def probe_checkpointer(_config):
|
||||||
|
try:
|
||||||
|
yield object()
|
||||||
|
finally:
|
||||||
|
events.append("checkpointer_closed")
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def fake_stream_bridge(_config):
|
||||||
|
yield object()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def fake_store(_config):
|
||||||
|
yield object()
|
||||||
|
|
||||||
|
async def fake_init_engine(_db):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def fake_close_engine():
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def spy_shutdown(self, *, timeout): # noqa: ANN001
|
||||||
|
events.append("runs_drained")
|
||||||
|
|
||||||
|
monkeypatch.setattr("deerflow.runtime.checkpointer.async_provider.make_checkpointer", probe_checkpointer)
|
||||||
|
monkeypatch.setattr("deerflow.runtime.make_stream_bridge", fake_stream_bridge)
|
||||||
|
monkeypatch.setattr("deerflow.runtime.make_store", fake_store)
|
||||||
|
monkeypatch.setattr("deerflow.persistence.engine.init_engine_from_config", fake_init_engine)
|
||||||
|
monkeypatch.setattr("deerflow.persistence.engine.close_engine", fake_close_engine)
|
||||||
|
monkeypatch.setattr("deerflow.persistence.engine.get_session_factory", lambda: None)
|
||||||
|
monkeypatch.setattr("deerflow.runtime.events.store.make_run_event_store", lambda _cfg: object())
|
||||||
|
monkeypatch.setattr("deerflow.persistence.thread_meta.make_thread_store", lambda _sf, _store: object())
|
||||||
|
monkeypatch.setattr(RunManager, "shutdown", spy_shutdown, raising=False)
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
startup_config = SimpleNamespace(database=SimpleNamespace(backend="memory"), run_events=None)
|
||||||
|
|
||||||
|
async with langgraph_runtime(app, startup_config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert "runs_drained" in events, "langgraph_runtime never drained in-flight runs on shutdown"
|
||||||
|
assert "checkpointer_closed" in events
|
||||||
|
assert events.index("runs_drained") < events.index("checkpointer_closed"), f"runs must be drained before the checkpointer pool is closed; got order {events}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drain_flushes_real_graph_checkpoint_before_close():
|
||||||
|
"""End-to-end #3373 guard with a REAL langgraph graph + checkpointer.
|
||||||
|
|
||||||
|
A real run is driven through ``graph.astream`` in a background task, then
|
||||||
|
``RunManager.shutdown()`` drains it. The checkpointer raises once closed
|
||||||
|
(mirroring ``psycopg_pool.PoolClosed``). Closing only happens AFTER the
|
||||||
|
drain — as the gateway's AsyncExitStack does. The drain must let langgraph
|
||||||
|
flush its final checkpoint while the checkpointer is still open, so no write
|
||||||
|
lands against a closed checkpointer.
|
||||||
|
|
||||||
|
Unlike the unit/spy tests above, this exercises the real langgraph
|
||||||
|
checkpoint-put machinery, so a future langgraph change that cancels (rather
|
||||||
|
than awaits) its checkpoint-put task on executor exit would fail this test
|
||||||
|
instead of silently regressing #3373.
|
||||||
|
"""
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
|
async def slow(_state: _CountState) -> dict:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
return {"count": 1}
|
||||||
|
|
||||||
|
saver = _CloseableSaver()
|
||||||
|
builder = StateGraph(_CountState)
|
||||||
|
for name in ("a", "b", "c"):
|
||||||
|
builder.add_node(name, slow)
|
||||||
|
builder.add_edge(START, "a")
|
||||||
|
builder.add_edge("a", "b")
|
||||||
|
builder.add_edge("b", "c")
|
||||||
|
builder.add_edge("c", END)
|
||||||
|
graph = builder.compile(checkpointer=saver)
|
||||||
|
|
||||||
|
rm = RunManager()
|
||||||
|
record = await rm.create("t-e2e")
|
||||||
|
await rm.set_status(record.run_id, RunStatus.running)
|
||||||
|
thread_cfg = {"configurable": {"thread_id": "t-e2e"}}
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
started.set()
|
||||||
|
async for _ in graph.astream({"count": 0}, config=thread_cfg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
record.task = asyncio.create_task(run())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
|
# Deterministically wait until the run is genuinely in-flight — poll for
|
||||||
|
# the first persisted checkpoint instead of a fixed sleep (avoids CI
|
||||||
|
# flakiness on slow runners / under event-loop contention).
|
||||||
|
async def _await_first_checkpoint() -> None:
|
||||||
|
while (await saver.aget_tuple(thread_cfg)) is None:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
await asyncio.wait_for(_await_first_checkpoint(), timeout=5.0)
|
||||||
|
|
||||||
|
# The fix: drain while the checkpointer is still open ...
|
||||||
|
await rm.shutdown(timeout=5.0)
|
||||||
|
# ... and only then close it (mirrors langgraph_runtime's ExitStack).
|
||||||
|
saver.close()
|
||||||
|
|
||||||
|
assert saver.writes_after_close == [], f"a checkpoint write raced a closed checkpointer: {saver.writes_after_close}"
|
||||||
|
# The final checkpoint landed before close.
|
||||||
|
snapshot = await saver.aget_tuple(thread_cfg)
|
||||||
|
assert snapshot is not None
|
||||||
|
finally:
|
||||||
|
if not record.task.done():
|
||||||
|
record.task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await record.task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_preserves_status_of_run_completed_during_drain():
|
||||||
|
"""A run that finishes (e.g. success) during the drain window must keep its
|
||||||
|
real terminal status — shutdown must not blanket-overwrite it to
|
||||||
|
``interrupted`` in memory or in the store (Copilot review on PR #3381)."""
|
||||||
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
|
store = MemoryRunStore()
|
||||||
|
rm = RunManager(store=store)
|
||||||
|
record = await rm.create("t-complete")
|
||||||
|
await rm.set_status(record.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
async def worker() -> None:
|
||||||
|
try:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# The run had effectively finished; swallow the cancellation and
|
||||||
|
# record success, like a run that completed in the same tick the
|
||||||
|
# shutdown cancelled it.
|
||||||
|
pass
|
||||||
|
await rm.set_status(record.run_id, RunStatus.success)
|
||||||
|
|
||||||
|
record.task = asyncio.create_task(worker())
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0) # let the task reach its await point
|
||||||
|
|
||||||
|
await rm.shutdown(timeout=5.0)
|
||||||
|
|
||||||
|
assert record.status == RunStatus.success, f"shutdown overwrote in-memory status: {record.status}"
|
||||||
|
persisted = await store.get(record.run_id)
|
||||||
|
assert persisted is not None and persisted["status"] == "success", f"shutdown overwrote persisted status: {persisted}"
|
||||||
|
finally:
|
||||||
|
if not record.task.done():
|
||||||
|
record.task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await record.task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_surfaces_failed_interrupted_persist(caplog):
|
||||||
|
"""A failed interrupted-status persist during the drain must be surfaced (with
|
||||||
|
the run_id), not silently swallowed by the gather (maintainer review on
|
||||||
|
PR #3381)."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
|
class _FailingStore(MemoryRunStore):
|
||||||
|
async def update_status(self, *args, **kwargs):
|
||||||
|
raise RuntimeError("store unavailable")
|
||||||
|
|
||||||
|
rm = RunManager(store=_FailingStore())
|
||||||
|
record = await rm.create("t-failpersist")
|
||||||
|
record.status = RunStatus.running # set in memory; the failing store is exercised by the drain
|
||||||
|
|
||||||
|
started = asyncio.Event()
|
||||||
|
|
||||||
|
async def worker() -> None:
|
||||||
|
started.set()
|
||||||
|
await asyncio.Event().wait() # blocks until cancelled by the drain
|
||||||
|
|
||||||
|
record.task = asyncio.create_task(worker())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
with caplog.at_level(logging.WARNING, logger="deerflow.runtime.runs.manager"):
|
||||||
|
await rm.shutdown(timeout=5.0)
|
||||||
|
assert "Could not persist interrupted status for run" in caplog.text, caplog.text
|
||||||
|
finally:
|
||||||
|
if not record.task.done():
|
||||||
|
record.task.cancel()
|
||||||
|
with suppress(asyncio.CancelledError):
|
||||||
|
await record.task
|
||||||
@@ -32,6 +32,7 @@ class _FakeRunManager:
|
|||||||
self.store = store
|
self.store = store
|
||||||
self.reconcile_calls: list[dict] = []
|
self.reconcile_calls: list[dict] = []
|
||||||
self.list_by_thread_calls: list[dict] = []
|
self.list_by_thread_calls: list[dict] = []
|
||||||
|
self.shutdown_calls: int = 0
|
||||||
_FakeRunManager.instances.append(self)
|
_FakeRunManager.instances.append(self)
|
||||||
|
|
||||||
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
|
async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
|
||||||
@@ -42,6 +43,11 @@ class _FakeRunManager:
|
|||||||
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
|
self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
|
||||||
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
|
return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit])
|
||||||
|
|
||||||
|
async def shutdown(self, *, timeout: float = 5.0) -> None:
|
||||||
|
# No in-flight tasks in these startup-recovery tests; langgraph_runtime
|
||||||
|
# drains the manager on teardown, so the double must accept the call.
|
||||||
|
self.shutdown_calls += 1
|
||||||
|
|
||||||
|
|
||||||
class _FakeThreadStore:
|
class _FakeThreadStore:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user