mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(gateway): honour on_disconnect on /wait endpoints (#3267)
* fix(gateway): honour on_disconnect on /wait endpoints (#3265) The non-streaming /threads/{tid}/runs/wait and /runs/wait handlers used to await record.task directly with no disconnect handling and silently swallow CancelledError. When a long tool call (e.g. pip install inside a custom skill) kept the connection idle long enough for an intermediate HTTP layer to time out, the handler would still read the in-progress checkpoint and return it as if the run had completed normally -- masking a half-finished run as a successful response. Add wait_for_run_completion in app.gateway.services that mirrors sse_consumer's bridge-consumption pattern: subscribe to the stream bridge until END_SENTINEL, poll request.is_disconnected on every wake-up, and on real client disconnect cancel the background run when record.on_disconnect is "cancel". Wire it into both wait endpoints. The streaming path was unaffected because sse_consumer already has this loop; this just brings /wait to parity. * fix(gateway): skip checkpoint serialization on /wait disconnect Copilot review on #3267 caught a follow-on of the same #3265 bug: when the client disconnects, wait_for_run_completion breaks out of the bridge loop and cancels the run, but the /wait endpoint then continues to read the checkpointer and serializes whatever partial checkpoint exists as a normal 200 response. Have the helper return a bool — True only when END_SENTINEL was observed — and skip the checkpoint serialization path on False. Also reorder the inner check so END_SENTINEL is honoured even when is_disconnected() flips true in the same iteration; the run truly finished so the real final checkpoint is still valid.
This commit is contained in:
@@ -7,7 +7,6 @@ is reused so that conversation history is preserved across calls.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
@@ -17,7 +16,7 @@ from fastapi.responses import StreamingResponse
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.routers.thread_runs import RunCreateRequest
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,24 +65,25 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
|
||||
Otherwise a new temporary thread is created.
|
||||
"""
|
||||
thread_id = _resolve_thread_id(body)
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
completed = True
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
if completed:
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
||||
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -175,24 +175,25 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
|
||||
@require_permission("runs", "create", owner_check=True, require_existing=True)
|
||||
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
|
||||
"""Create a run and block until it completes, returning the final state."""
|
||||
bridge = get_stream_bridge(request)
|
||||
run_mgr = get_run_manager(request)
|
||||
record = await start_run(body, thread_id, request)
|
||||
|
||||
completed = True
|
||||
if record.task is not None:
|
||||
try:
|
||||
await record.task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
completed = await wait_for_run_completion(bridge, record, request, run_mgr)
|
||||
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
if completed:
|
||||
checkpointer = get_checkpointer(request)
|
||||
config = {"configurable": {"thread_id": thread_id}}
|
||||
try:
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(config)
|
||||
if checkpoint_tuple is not None:
|
||||
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint.get("channel_values", {})
|
||||
return serialize_channel_values(channel_values)
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch final state for run %s", record.run_id)
|
||||
|
||||
return {"status": record.status.value, "error": record.error}
|
||||
|
||||
|
||||
@@ -402,3 +402,51 @@ async def sse_consumer(
|
||||
if record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
|
||||
|
||||
async def wait_for_run_completion(
|
||||
bridge: StreamBridge,
|
||||
record: RunRecord,
|
||||
request: Request,
|
||||
run_mgr: RunManager,
|
||||
) -> bool:
|
||||
"""Block until the run publishes ``END_SENTINEL``, honouring on_disconnect.
|
||||
|
||||
The non-streaming ``/wait`` endpoints used to ``await record.task``
|
||||
directly with no disconnect handling. When the client (or an
|
||||
intermediate HTTP proxy) timed out during a long tool call such as
|
||||
``pip install``, the handler would swallow ``CancelledError`` and
|
||||
serialize whatever checkpoint happened to exist — masking a half-finished
|
||||
run as a normal completion (issue #3265).
|
||||
|
||||
This helper consumes the same bridge that ``sse_consumer`` does so the
|
||||
wait path shares its disconnect semantics: each wake-up polls
|
||||
``request.is_disconnected()``; on a real disconnect it cancels the
|
||||
background run when ``record.on_disconnect`` is ``cancel``. The bridge's
|
||||
heartbeat sentinels guarantee at least one wake-up per
|
||||
``heartbeat_interval`` even when the agent emits no events for a while.
|
||||
|
||||
Returns:
|
||||
``True`` when ``END_SENTINEL`` was observed (run reached a terminal
|
||||
state), ``False`` when the loop exited because the client
|
||||
disconnected. Callers must skip checkpoint serialization on
|
||||
``False`` so a partial checkpoint is not returned as a normal
|
||||
response.
|
||||
"""
|
||||
completed = False
|
||||
try:
|
||||
async for entry in bridge.subscribe(record.run_id):
|
||||
# END_SENTINEL means the run reached a terminal state; honour it
|
||||
# even if the client just disconnected so the caller still serializes
|
||||
# the real final checkpoint.
|
||||
if entry is END_SENTINEL:
|
||||
completed = True
|
||||
return True
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
# Heartbeats and regular events: keep waiting for END_SENTINEL.
|
||||
return completed
|
||||
finally:
|
||||
if not completed and record.status in (RunStatus.pending, RunStatus.running):
|
||||
if record.on_disconnect == DisconnectMode.cancel:
|
||||
await run_mgr.cancel(record.run_id)
|
||||
|
||||
Reference in New Issue
Block a user