fix(gateway): honour on_disconnect on /wait endpoints (#3267)

* fix(gateway): honour on_disconnect on /wait endpoints (#3265)

The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used
to await record.task directly with no disconnect handling and silently
swallow CancelledError. When a long tool call (e.g. pip install inside
a custom skill) kept the connection idle long enough for an
intermediate HTTP layer to time out, the handler would still read the
in-progress checkpoint and return it as if the run had completed
normally -- masking a half-finished run as a successful response.

Add wait_for_run_completion in app.gateway.services that mirrors
sse_consumer's bridge-consumption pattern: subscribe to the stream
bridge until END_SENTINEL, poll request.is_disconnected on every
wake-up, and on real client disconnect cancel the background run when
record.on_disconnect is "cancel". Wire it into both wait endpoints.

The streaming path was unaffected because sse_consumer already has
this loop; this just brings /wait to parity.

* fix(gateway): skip checkpoint serialization on /wait disconnect

Copilot review on #3267 caught a follow-on of the same #3265 bug: when
the client disconnects, wait_for_run_completion breaks out of the bridge
loop and cancels the run, but the /wait endpoint then continues to read
the checkpointer and serializes whatever partial checkpoint exists as a
normal 200 response.

Have the helper return a bool — True only when END_SENTINEL was observed
— and skip the checkpoint serialization path on False. Also reorder the
inner check so END_SENTINEL is honoured even when is_disconnected() flips
true in the same iteration; the run truly finished so the real final
checkpoint is still valid.
This commit is contained in:
AochenShen99
2026-05-28 07:22:39 +08:00
committed by GitHub
parent 9e332c594a
commit a5599c100c
5 changed files with 258 additions and 31 deletions
+16 -16
View File
@@ -7,7 +7,6 @@ is reused so that conversation history is preserved across calls.
from __future__ import annotations
import asyncio
import logging
import uuid
@@ -17,7 +16,7 @@ from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
@@ -66,24 +65,25 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
completed = True
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
if completed:
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
+16 -15
View File
@@ -21,7 +21,7 @@ from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
logger = logging.getLogger(__name__)
@@ -175,24 +175,25 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
completed = True
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
if completed:
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
+48
View File
@@ -402,3 +402,51 @@ async def sse_consumer(
if record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)
async def wait_for_run_completion(
bridge: StreamBridge,
record: RunRecord,
request: Request,
run_mgr: RunManager,
) -> bool:
"""Block until the run publishes ``END_SENTINEL``, honouring on_disconnect.
The non-streaming ``/wait`` endpoints used to ``await record.task``
directly with no disconnect handling. When the client (or an
intermediate HTTP proxy) timed out during a long tool call such as
``pip install``, the handler would swallow ``CancelledError`` and
serialize whatever checkpoint happened to exist — masking a half-finished
run as a normal completion (issue #3265).
This helper consumes the same bridge that ``sse_consumer`` does so the
wait path shares its disconnect semantics: each wake-up polls
``request.is_disconnected()``; on a real disconnect it cancels the
background run when ``record.on_disconnect`` is ``cancel``. The bridge's
heartbeat sentinels guarantee at least one wake-up per
``heartbeat_interval`` even when the agent emits no events for a while.
Returns:
``True`` when ``END_SENTINEL`` was observed (run reached a terminal
state), ``False`` when the loop exited because the client
disconnected. Callers must skip checkpoint serialization on
``False`` so a partial checkpoint is not returned as a normal
response.
"""
completed = False
try:
async for entry in bridge.subscribe(record.run_id):
# END_SENTINEL means the run reached a terminal state; honour it
# even if the client just disconnected so the caller still serializes
# the real final checkpoint.
if entry is END_SENTINEL:
completed = True
return True
if await request.is_disconnected():
break
# Heartbeats and regular events: keep waiting for END_SENTINEL.
return completed
finally:
if not completed and record.status in (RunStatus.pending, RunStatus.running):
if record.on_disconnect == DisconnectMode.cancel:
await run_mgr.cancel(record.run_id)