mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
refactor(runtime): restructure runs module with new execution architecture
Major refactoring of deerflow/runtime/: - runs/callbacks/ - new callback system (builder, events, title, tokens) - runs/internal/ - execution internals (executor, supervisor, stream_logic, registry) - runs/internal/execution/ - execution artifacts and events handling - runs/facade.py - high-level run facade - runs/observer.py - run observation protocol - runs/types.py - type definitions - runs/store/ - simplified store interfaces (create, delete, query, event) Refactor stream_bridge/: - Replace old providers with contract.py and exceptions.py - Remove async_provider.py, base.py, memory.py Add documentation: - README.md and README_zh.md for runtime module Remove deprecated: - manager.py moved to internal/ - worker.py, schemas.py - user_context.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
"""Internal runs implementation modules.
|
||||
|
||||
These modules are implementation details behind the public runs surface.
|
||||
"""
|
||||
@@ -0,0 +1 @@
|
||||
"""Internal execution components for runs domain."""
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Execution preparation helpers for a single run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunBuildArtifacts:
|
||||
"""Assembled agent runtime pieces for a single run."""
|
||||
|
||||
agent: Any
|
||||
runnable_config: dict[str, Any]
|
||||
reference_store: Any | None = None
|
||||
|
||||
|
||||
def build_run_artifacts(
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
checkpointer: Any | None,
|
||||
store: Any | None,
|
||||
agent_factory: Any,
|
||||
config: dict[str, Any],
|
||||
bridge: StreamBridge,
|
||||
interrupt_before: list[str] | None = None,
|
||||
interrupt_after: list[str] | None = None,
|
||||
callbacks: list[BaseCallbackHandler] | None = None,
|
||||
) -> RunBuildArtifacts:
|
||||
"""Assemble all components needed for agent execution."""
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
config_callbacks = config.setdefault("callbacks", [])
|
||||
if callbacks:
|
||||
config_callbacks.extend(callbacks)
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
if checkpointer is not None:
|
||||
agent.checkpointer = checkpointer
|
||||
if store is not None:
|
||||
agent.store = store
|
||||
|
||||
if interrupt_before:
|
||||
agent.interrupt_before_nodes = interrupt_before
|
||||
if interrupt_after:
|
||||
agent.interrupt_after_nodes = interrupt_after
|
||||
|
||||
return RunBuildArtifacts(
|
||||
agent=agent,
|
||||
runnable_config=dict(runnable_config),
|
||||
reference_store=store,
|
||||
)
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Lifecycle event helpers for run execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from ...observer import LifecycleEventType, RunLifecycleEvent, RunObserver
|
||||
|
||||
|
||||
class RunEventEmitter:
|
||||
"""Build and dispatch lifecycle events for a single run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
observer: RunObserver,
|
||||
) -> None:
|
||||
self._run_id = run_id
|
||||
self._thread_id = thread_id
|
||||
self._observer = observer
|
||||
self._sequence = 0
|
||||
|
||||
@property
|
||||
def sequence(self) -> int:
|
||||
return self._sequence
|
||||
|
||||
async def emit(
|
||||
self,
|
||||
event_type: LifecycleEventType,
|
||||
payload: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._sequence += 1
|
||||
event = RunLifecycleEvent(
|
||||
event_id=f"{self._run_id}:{event_type.value}:{self._sequence}",
|
||||
event_type=event_type,
|
||||
run_id=self._run_id,
|
||||
thread_id=self._thread_id,
|
||||
sequence=self._sequence,
|
||||
occurred_at=datetime.now(UTC),
|
||||
payload=payload or {},
|
||||
)
|
||||
await self._observer.on_event(event)
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Single-run execution orchestrator and execution-local helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.runtime.serialization import serialize
|
||||
from deerflow.runtime.stream_bridge import StreamBridge, StreamStatus
|
||||
|
||||
from ...callbacks.builder import RunCallbackArtifacts, build_run_callbacks
|
||||
from ...observer import LifecycleEventType, RunObserver, RunResult
|
||||
from ...store import RunEventStore
|
||||
from ...types import RunStatus
|
||||
from .artifacts import build_run_artifacts
|
||||
from .events import RunEventEmitter
|
||||
from .stream_logic import external_stream_event_name, normalize_stream_modes, should_filter_event, unpack_stream_item
|
||||
from .supervisor import RunHandle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RunExecution:
|
||||
"""Encapsulate the lifecycle of a single run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bridge: StreamBridge,
|
||||
run_manager: Any,
|
||||
record: Any,
|
||||
checkpointer: Any | None = None,
|
||||
store: Any | None = None,
|
||||
event_store: RunEventStore | None = None,
|
||||
ctx: Any | None = None,
|
||||
agent_factory: Any,
|
||||
graph_input: dict,
|
||||
config: dict,
|
||||
observer: RunObserver,
|
||||
stream_modes: list[str] | None,
|
||||
stream_subgraphs: bool,
|
||||
interrupt_before: list[str] | Literal["*"] | None,
|
||||
interrupt_after: list[str] | Literal["*"] | None,
|
||||
handle: RunHandle | None = None,
|
||||
) -> None:
|
||||
if ctx is not None:
|
||||
checkpointer = getattr(ctx, "checkpointer", checkpointer)
|
||||
store = getattr(ctx, "store", store)
|
||||
|
||||
self.bridge = bridge
|
||||
self.run_manager = run_manager
|
||||
self.record = record
|
||||
self.checkpointer = checkpointer
|
||||
self.store = store
|
||||
self.event_store = event_store
|
||||
self.agent_factory = agent_factory
|
||||
self.graph_input = graph_input
|
||||
self.config = config
|
||||
self.observer = observer
|
||||
self.stream_modes = stream_modes
|
||||
self.stream_subgraphs = stream_subgraphs
|
||||
self.interrupt_before = interrupt_before
|
||||
self.interrupt_after = interrupt_after
|
||||
self.handle = handle
|
||||
|
||||
self.run_id = record.run_id
|
||||
self.thread_id = record.thread_id
|
||||
self._pre_run_checkpoint_id: str | None = None
|
||||
self._emitter = RunEventEmitter(
|
||||
run_id=self.run_id,
|
||||
thread_id=self.thread_id,
|
||||
observer=observer,
|
||||
)
|
||||
self.result = RunResult(
|
||||
run_id=self.run_id,
|
||||
thread_id=self.thread_id,
|
||||
status=RunStatus.pending,
|
||||
)
|
||||
self._agent: Any = None
|
||||
self._runnable_config: dict[str, Any] = {}
|
||||
self._lg_modes: list[str] = []
|
||||
self._callback_artifacts: RunCallbackArtifacts | None = None
|
||||
|
||||
@property
|
||||
def _event_sequence(self) -> int:
|
||||
return self._emitter.sequence
|
||||
|
||||
async def _emit(
|
||||
self,
|
||||
event_type: LifecycleEventType,
|
||||
payload: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
await self._emitter.emit(event_type, payload)
|
||||
|
||||
async def _start(self) -> None:
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.running)
|
||||
await self._emit(LifecycleEventType.RUN_STARTED, {})
|
||||
|
||||
human_msg = self._extract_human_message()
|
||||
if human_msg is not None:
|
||||
await self._emit(
|
||||
LifecycleEventType.HUMAN_MESSAGE,
|
||||
{"message": human_msg.model_dump()},
|
||||
)
|
||||
|
||||
await self._capture_pre_run_checkpoint()
|
||||
await self.bridge.publish(
|
||||
self.run_id,
|
||||
"metadata",
|
||||
{"run_id": self.run_id, "thread_id": self.thread_id},
|
||||
)
|
||||
|
||||
def _extract_human_message(self) -> Any:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
messages = self.graph_input.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1] if isinstance(messages, list) else messages
|
||||
if isinstance(last, HumanMessage):
|
||||
return last
|
||||
if isinstance(last, str):
|
||||
return HumanMessage(content=last) if last else None
|
||||
if hasattr(last, "content"):
|
||||
return HumanMessage(content=last.content)
|
||||
if isinstance(last, dict):
|
||||
content = last.get("content", "")
|
||||
return HumanMessage(content=content) if content else None
|
||||
return None
|
||||
|
||||
async def _capture_pre_run_checkpoint(self) -> None:
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await self.checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
self._pre_run_checkpoint_id = (
|
||||
getattr(ckpt_tuple, "config", {})
|
||||
.get("configurable", {})
|
||||
.get("checkpoint_id")
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not get pre-run checkpoint_id for run %s", self.run_id)
|
||||
|
||||
async def _prepare(self) -> None:
|
||||
config = dict(self.config)
|
||||
existing_callbacks = config.pop("callbacks", [])
|
||||
if existing_callbacks is None:
|
||||
existing_callbacks = []
|
||||
elif not isinstance(existing_callbacks, list):
|
||||
existing_callbacks = [existing_callbacks]
|
||||
|
||||
self._callback_artifacts = build_run_callbacks(
|
||||
record=self.record,
|
||||
graph_input=self.graph_input,
|
||||
event_store=self.event_store,
|
||||
existing_callbacks=existing_callbacks,
|
||||
)
|
||||
|
||||
artifacts = build_run_artifacts(
|
||||
thread_id=self.thread_id,
|
||||
run_id=self.run_id,
|
||||
checkpointer=self.checkpointer,
|
||||
store=self.store,
|
||||
agent_factory=self.agent_factory,
|
||||
config=config,
|
||||
bridge=self.bridge,
|
||||
interrupt_before=self.interrupt_before,
|
||||
interrupt_after=self.interrupt_after,
|
||||
callbacks=self._callback_artifacts.callbacks,
|
||||
)
|
||||
|
||||
self._agent = artifacts.agent
|
||||
self._runnable_config = artifacts.runnable_config
|
||||
self._lg_modes = normalize_stream_modes(self.stream_modes)
|
||||
logger.info(
|
||||
"Run %s: streaming with modes %s (requested: %s)",
|
||||
self.run_id,
|
||||
self._lg_modes,
|
||||
self.stream_modes,
|
||||
)
|
||||
|
||||
async def _finish_success(self) -> None:
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.success)
|
||||
await self.bridge.publish_terminal(self.run_id, StreamStatus.ENDED)
|
||||
self.result.status = RunStatus.success
|
||||
completion_data = self._completion_data()
|
||||
title = self._callback_title() or await self._extract_title_from_checkpoint()
|
||||
self.result.title = title
|
||||
self.result.completion_data = completion_data
|
||||
await self._emit(
|
||||
LifecycleEventType.RUN_COMPLETED,
|
||||
{
|
||||
"title": title,
|
||||
"completion_data": completion_data,
|
||||
},
|
||||
)
|
||||
|
||||
async def _finish_aborted(self, cancel_mode: str) -> None:
|
||||
payload = {
|
||||
"cancel_mode": cancel_mode,
|
||||
"pre_run_checkpoint_id": self._pre_run_checkpoint_id,
|
||||
"completion_data": self._completion_data(),
|
||||
}
|
||||
|
||||
if cancel_mode == "rollback":
|
||||
await self.run_manager.set_status(
|
||||
self.run_id,
|
||||
RunStatus.error,
|
||||
error="Rolled back by user",
|
||||
)
|
||||
await self.bridge.publish_terminal(
|
||||
self.run_id,
|
||||
StreamStatus.CANCELLED,
|
||||
{"cancel_mode": "rollback", "message": "Rolled back by user"},
|
||||
)
|
||||
self.result.status = RunStatus.error
|
||||
self.result.error = "Rolled back by user"
|
||||
logger.info("Run %s rolled back", self.run_id)
|
||||
else:
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.interrupted)
|
||||
await self.bridge.publish_terminal(
|
||||
self.run_id,
|
||||
StreamStatus.CANCELLED,
|
||||
{"cancel_mode": cancel_mode},
|
||||
)
|
||||
self.result.status = RunStatus.interrupted
|
||||
logger.info("Run %s cancelled (mode=%s)", self.run_id, cancel_mode)
|
||||
|
||||
await self._emit(LifecycleEventType.RUN_CANCELLED, payload)
|
||||
|
||||
async def _finish_failed(self, exc: Exception) -> None:
|
||||
error_msg = str(exc)
|
||||
logger.exception("Run %s failed: %s", self.run_id, error_msg)
|
||||
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.error, error=error_msg)
|
||||
await self.bridge.publish_terminal(
|
||||
self.run_id,
|
||||
StreamStatus.ERRORED,
|
||||
{"message": error_msg, "name": type(exc).__name__},
|
||||
)
|
||||
self.result.status = RunStatus.error
|
||||
self.result.error = error_msg
|
||||
|
||||
await self._emit(
|
||||
LifecycleEventType.RUN_FAILED,
|
||||
{
|
||||
"error": error_msg,
|
||||
"error_type": type(exc).__name__,
|
||||
"completion_data": self._completion_data(),
|
||||
},
|
||||
)
|
||||
|
||||
def _completion_data(self) -> dict[str, object]:
|
||||
if self._callback_artifacts is None:
|
||||
return {}
|
||||
return self._callback_artifacts.completion_data().to_dict()
|
||||
|
||||
def _callback_title(self) -> str | None:
|
||||
if self._callback_artifacts is None:
|
||||
return None
|
||||
return self._callback_artifacts.title()
|
||||
|
||||
async def _extract_title_from_checkpoint(self) -> str | None:
|
||||
if self.checkpointer is None:
|
||||
return None
|
||||
try:
|
||||
ckpt_config = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await self.checkpointer.aget_tuple(ckpt_config)
|
||||
if ckpt_tuple is not None:
|
||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||
return ckpt.get("channel_values", {}).get("title")
|
||||
except Exception:
|
||||
logger.debug("Failed to extract title from checkpoint for thread %s", self.thread_id)
|
||||
return None
|
||||
|
||||
def _map_run_status_to_thread_status(self, status: RunStatus) -> str:
|
||||
if status == RunStatus.success:
|
||||
return "idle"
|
||||
if status == RunStatus.interrupted:
|
||||
return "interrupted"
|
||||
if status in (RunStatus.error, RunStatus.timeout):
|
||||
return "error"
|
||||
return "running"
|
||||
|
||||
def _abort_requested(self) -> bool:
|
||||
if self.handle is not None:
|
||||
return self.handle.cancel_event.is_set()
|
||||
return self.record.abort_event.is_set()
|
||||
|
||||
def _abort_action(self) -> str:
|
||||
if self.handle is not None:
|
||||
return self.handle.cancel_action
|
||||
return self.record.abort_action
|
||||
|
||||
async def _stream(self) -> None:
|
||||
runnable_config = RunnableConfig(**self._runnable_config)
|
||||
|
||||
if len(self._lg_modes) == 1 and not self.stream_subgraphs:
|
||||
single_mode = self._lg_modes[0]
|
||||
async for chunk in self._agent.astream(
|
||||
self.graph_input,
|
||||
config=runnable_config,
|
||||
stream_mode=single_mode,
|
||||
):
|
||||
if self._abort_requested():
|
||||
logger.info("Run %s abort requested - stopping", self.run_id)
|
||||
break
|
||||
if should_filter_event(single_mode, chunk):
|
||||
continue
|
||||
await self.bridge.publish(
|
||||
self.run_id,
|
||||
external_stream_event_name(single_mode),
|
||||
serialize(chunk, mode=single_mode),
|
||||
)
|
||||
return
|
||||
|
||||
async for item in self._agent.astream(
|
||||
self.graph_input,
|
||||
config=runnable_config,
|
||||
stream_mode=self._lg_modes,
|
||||
subgraphs=self.stream_subgraphs,
|
||||
):
|
||||
if self._abort_requested():
|
||||
logger.info("Run %s abort requested - stopping", self.run_id)
|
||||
break
|
||||
|
||||
mode, chunk = unpack_stream_item(item, self._lg_modes, stream_subgraphs=self.stream_subgraphs)
|
||||
if mode is None:
|
||||
continue
|
||||
if should_filter_event(mode, chunk):
|
||||
continue
|
||||
await self.bridge.publish(
|
||||
self.run_id,
|
||||
external_stream_event_name(mode),
|
||||
serialize(chunk, mode=mode),
|
||||
)
|
||||
|
||||
async def _finish_after_stream(self) -> None:
|
||||
if self._abort_requested():
|
||||
action = self._abort_action()
|
||||
cancel_mode = "rollback" if action == "rollback" else "interrupt"
|
||||
await self._finish_aborted(cancel_mode)
|
||||
return
|
||||
|
||||
await self._finish_success()
|
||||
|
||||
async def _emit_final_thread_status(self) -> None:
|
||||
final_thread_status = self._map_run_status_to_thread_status(self.result.status)
|
||||
await self._emit(
|
||||
LifecycleEventType.THREAD_STATUS_UPDATED,
|
||||
{"status": final_thread_status},
|
||||
)
|
||||
|
||||
async def run(self) -> RunResult:
|
||||
try:
|
||||
await self._start()
|
||||
await self._prepare()
|
||||
await self._stream()
|
||||
await self._finish_after_stream()
|
||||
except asyncio.CancelledError:
|
||||
await self._finish_aborted("task_cancelled")
|
||||
except Exception as exc:
|
||||
await self._finish_failed(exc)
|
||||
finally:
|
||||
await self._emit_final_thread_status()
|
||||
if self._callback_artifacts is not None:
|
||||
await self._callback_artifacts.flush()
|
||||
await self.bridge.cleanup(self.run_id)
|
||||
|
||||
return self.result
|
||||
|
||||
|
||||
__all__ = ["_RunExecution"]
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Execution-local stream processing helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StreamItem:
|
||||
"""Normalized stream item from LangGraph."""
|
||||
|
||||
mode: str
|
||||
chunk: Any
|
||||
|
||||
|
||||
_FILTERED_NODES = frozenset({"__start__", "__end__"})
|
||||
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
||||
|
||||
|
||||
def normalize_stream_modes(requested_modes: list[str] | None) -> list[str]:
|
||||
"""Normalize requested stream modes to valid LangGraph modes."""
|
||||
input_modes: list[str] = list(requested_modes or ["values"])
|
||||
|
||||
lg_modes: list[str] = []
|
||||
for mode in input_modes:
|
||||
if mode == "messages-tuple":
|
||||
lg_modes.append("messages")
|
||||
elif mode == "events":
|
||||
logger.info("'events' stream_mode not supported (requires astream_events). Skipping.")
|
||||
continue
|
||||
elif mode in _VALID_LG_MODES:
|
||||
lg_modes.append(mode)
|
||||
|
||||
if not lg_modes:
|
||||
lg_modes = ["values"]
|
||||
|
||||
seen: set[str] = set()
|
||||
deduped: list[str] = []
|
||||
for mode in lg_modes:
|
||||
if mode not in seen:
|
||||
seen.add(mode)
|
||||
deduped.append(mode)
|
||||
|
||||
return deduped
|
||||
|
||||
|
||||
def unpack_stream_item(
|
||||
item: Any,
|
||||
lg_modes: list[str],
|
||||
*,
|
||||
stream_subgraphs: bool,
|
||||
) -> tuple[str | None, Any]:
|
||||
"""Unpack a multi-mode or subgraph stream item into ``(mode, chunk)``."""
|
||||
if stream_subgraphs:
|
||||
if isinstance(item, tuple) and len(item) == 3:
|
||||
_namespace, mode, chunk = item
|
||||
return str(mode), chunk
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
return str(mode), chunk
|
||||
return None, None
|
||||
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
return str(mode), chunk
|
||||
|
||||
return lg_modes[0] if lg_modes else None, item
|
||||
|
||||
|
||||
def should_filter_event(mode: str, chunk: Any) -> bool:
|
||||
"""Determine whether a stream event should be filtered before publish."""
|
||||
if mode == "updates" and isinstance(chunk, dict):
|
||||
node_names = set(chunk.keys())
|
||||
if node_names & _FILTERED_NODES:
|
||||
return True
|
||||
|
||||
if mode == "messages" and isinstance(chunk, tuple) and len(chunk) == 2:
|
||||
_, metadata = chunk
|
||||
if isinstance(metadata, dict):
|
||||
node = metadata.get("langgraph_node", "")
|
||||
if node in _FILTERED_NODES:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def external_stream_event_name(mode: str) -> str:
|
||||
"""Map LangGraph internal modes to the external SSE event contract."""
|
||||
return mode
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Active execution handle management for runs domain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ...types import CancelAction
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunHandle:
|
||||
"""In-process control handle for an active run."""
|
||||
|
||||
run_id: str
|
||||
task: asyncio.Task[Any] | None = None
|
||||
cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
cancel_action: CancelAction = "interrupt"
|
||||
|
||||
|
||||
class RunSupervisor:
|
||||
"""Own and control active run handles within the current process."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._handles: dict[str, RunHandle] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def launch(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
runner: Callable[[RunHandle], Awaitable[Any]],
|
||||
) -> RunHandle:
|
||||
"""Create a handle and start a background task for it."""
|
||||
handle = RunHandle(run_id=run_id)
|
||||
|
||||
async with self._lock:
|
||||
if run_id in self._handles:
|
||||
raise RuntimeError(f"Run {run_id} is already active")
|
||||
self._handles[run_id] = handle
|
||||
|
||||
task = asyncio.create_task(runner(handle))
|
||||
handle.task = task
|
||||
task.add_done_callback(lambda _: asyncio.create_task(self.cleanup(run_id)))
|
||||
return handle
|
||||
|
||||
async def cancel(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
action: CancelAction = "interrupt",
|
||||
) -> bool:
|
||||
"""Signal cancellation for an active handle."""
|
||||
async with self._lock:
|
||||
handle = self._handles.get(run_id)
|
||||
if handle is None:
|
||||
return False
|
||||
|
||||
handle.cancel_action = action
|
||||
handle.cancel_event.set()
|
||||
if handle.task is not None and not handle.task.done():
|
||||
handle.task.cancel()
|
||||
|
||||
return True
|
||||
|
||||
def get_handle(self, run_id: str) -> RunHandle | None:
|
||||
"""Return the active handle for a run, if any."""
|
||||
return self._handles.get(run_id)
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
"""Remove a handle after optional delay."""
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async with self._lock:
|
||||
self._handles.pop(run_id, None)
|
||||
@@ -0,0 +1,253 @@
|
||||
"""In-memory run registry with optional persistent RunStore backing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from ..types import RunStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunRecord:
|
||||
"""Mutable record for a single run."""
|
||||
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None
|
||||
status: RunStatus
|
||||
on_disconnect: Literal["cancel", "continue"]
|
||||
multitask_strategy: str = "reject"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
kwargs: dict = field(default_factory=dict)
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
task: asyncio.Task | None = field(default=None, repr=False)
|
||||
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
||||
abort_action: str = "interrupt"
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
"""In-memory run registry with optional persistent RunStore backing.
|
||||
|
||||
All mutations are protected by an asyncio lock. When a ``store`` is
|
||||
provided, serializable metadata is also persisted to the store so
|
||||
that run history survives process restarts.
|
||||
"""
|
||||
|
||||
def __init__(self, store: Any | None = None) -> None:
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
|
||||
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
||||
"""Best-effort persist run record to backing store."""
|
||||
if self._store is None:
|
||||
return
|
||||
try:
|
||||
await self._store.put(
|
||||
record.run_id,
|
||||
thread_id=record.thread_id,
|
||||
assistant_id=record.assistant_id,
|
||||
status=record.status.value,
|
||||
multitask_strategy=record.multitask_strategy,
|
||||
metadata=record.metadata or {},
|
||||
kwargs=record.kwargs or {},
|
||||
created_at=record.created_at,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||
|
||||
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist token usage and completion data to the backing store."""
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_completion(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
*,
|
||||
on_disconnect: Literal["cancel", "continue"] = "cancel",
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Create a new pending run and register it."""
|
||||
run_id = str(uuid.uuid4())
|
||||
now = _now_iso()
|
||||
record = RunRecord(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending,
|
||||
on_disconnect=on_disconnect,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
def get(self, run_id: str) -> RunRecord | None:
|
||||
"""Return a run record by ID, or ``None``."""
|
||||
return self._runs.get(run_id)
|
||||
|
||||
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
|
||||
"""Return all runs for a given thread, newest first."""
|
||||
async with self._lock:
|
||||
# Dict insertion order matches creation order, so reversing it gives
|
||||
# us deterministic newest-first results even when timestamps tie.
|
||||
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
||||
|
||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||
"""Transition a run to a new status."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is None:
|
||||
logger.warning("set_status called for unknown run %s", run_id)
|
||||
return
|
||||
record.status = status
|
||||
record.updated_at = _now_iso()
|
||||
if error is not None:
|
||||
record.error = error
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_status(run_id, status.value, error=error)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
|
||||
logger.info("Run %s -> %s", run_id, status.value)
|
||||
|
||||
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
|
||||
"""Request cancellation of a run.
|
||||
|
||||
Args:
|
||||
run_id: The run ID to cancel.
|
||||
action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state.
|
||||
|
||||
Sets the abort event with the action reason and cancels the asyncio task.
|
||||
Returns ``True`` if the run was in-flight and cancellation was initiated.
|
||||
"""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is None:
|
||||
return False
|
||||
if record.status not in (RunStatus.pending, RunStatus.running):
|
||||
return False
|
||||
record.abort_action = action
|
||||
record.abort_event.set()
|
||||
if record.task is not None and not record.task.done():
|
||||
record.task.cancel()
|
||||
record.status = RunStatus.interrupted
|
||||
record.updated_at = _now_iso()
|
||||
logger.info("Run %s cancelled (action=%s)", run_id, action)
|
||||
return True
|
||||
|
||||
async def create_or_reject(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
*,
|
||||
on_disconnect: Literal["cancel", "continue"] = "cancel",
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
For ``reject`` strategy, raises ``ConflictError`` if thread
|
||||
already has a pending/running run. For ``interrupt``/``rollback``,
|
||||
cancels inflight runs before creating.
|
||||
|
||||
This method holds the lock across both the check and the insert,
|
||||
eliminating the TOCTOU race in separate ``has_inflight`` + ``create``.
|
||||
"""
|
||||
run_id = str(uuid.uuid4())
|
||||
now = _now_iso()
|
||||
|
||||
_supported_strategies = ("reject", "interrupt", "rollback")
|
||||
|
||||
async with self._lock:
|
||||
if multitask_strategy not in _supported_strategies:
|
||||
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
|
||||
|
||||
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
|
||||
|
||||
if multitask_strategy == "reject" and inflight:
|
||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||
|
||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||
for r in inflight:
|
||||
r.abort_action = multitask_strategy
|
||||
r.abort_event.set()
|
||||
if r.task is not None and not r.task.done():
|
||||
r.task.cancel()
|
||||
r.status = RunStatus.interrupted
|
||||
r.updated_at = now
|
||||
logger.info(
|
||||
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
|
||||
len(inflight),
|
||||
thread_id,
|
||||
multitask_strategy,
|
||||
)
|
||||
|
||||
record = RunRecord(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending,
|
||||
on_disconnect=on_disconnect,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
async def has_inflight(self, thread_id: str) -> bool:
|
||||
"""Return ``True`` if *thread_id* has a pending or running run."""
|
||||
async with self._lock:
|
||||
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
|
||||
"""Remove a run record after an optional delay."""
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
async with self._lock:
|
||||
self._runs.pop(run_id, None)
|
||||
logger.debug("Run record %s cleaned up", run_id)
|
||||
|
||||
|
||||
class ConflictError(Exception):
|
||||
"""Raised when multitask_strategy=reject and thread has inflight runs."""
|
||||
|
||||
|
||||
class UnsupportedStrategyError(Exception):
|
||||
"""Raised when a multitask_strategy value is not yet implemented."""
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Execution plan builder for runs domain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from ..types import RunRecord, RunSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExecutionPlan:
|
||||
"""Normalized execution inputs derived from a run record and spec."""
|
||||
|
||||
record: RunRecord
|
||||
graph_input: dict[str, Any]
|
||||
runnable_config: dict[str, Any]
|
||||
stream_modes: list[str]
|
||||
stream_subgraphs: bool
|
||||
interrupt_before: list[str] | Literal["*"] | None
|
||||
interrupt_after: list[str] | Literal["*"] | None
|
||||
|
||||
|
||||
class ExecutionPlanner:
|
||||
"""Build executor-ready plans from public run specs."""
|
||||
|
||||
def build(self, record: RunRecord, spec: RunSpec) -> ExecutionPlan:
|
||||
return ExecutionPlan(
|
||||
record=record,
|
||||
graph_input=self._normalize_graph_input(spec.input),
|
||||
runnable_config=deepcopy(spec.runnable_config),
|
||||
stream_modes=list(spec.stream_modes),
|
||||
stream_subgraphs=spec.stream_subgraphs,
|
||||
interrupt_before=spec.interrupt_before,
|
||||
interrupt_after=spec.interrupt_after,
|
||||
)
|
||||
|
||||
def _normalize_graph_input(self, raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
||||
if raw_input is None:
|
||||
return {}
|
||||
return deepcopy(raw_input)
|
||||
@@ -0,0 +1,146 @@
|
||||
"""In-memory run registry for runs domain state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..types import INFLIGHT_STATUSES, RunRecord, RunSpec, RunStatus
|
||||
|
||||
|
||||
class RunRegistry:
|
||||
"""In-memory source of truth for run records and their status."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: dict[str, RunRecord] = {}
|
||||
self._thread_index: dict[str, set[str]] = {} # thread_id -> set[run_id]
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def create(self, spec: RunSpec) -> RunRecord:
|
||||
"""Create a new RunRecord from RunSpec."""
|
||||
run_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
record = RunRecord(
|
||||
run_id=run_id,
|
||||
thread_id=spec.scope.thread_id,
|
||||
assistant_id=spec.assistant_id,
|
||||
status="pending",
|
||||
temporary=spec.scope.temporary,
|
||||
multitask_strategy=spec.multitask_strategy,
|
||||
metadata=dict(spec.metadata),
|
||||
follow_up_to_run_id=spec.follow_up_to_run_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
self._records[run_id] = record
|
||||
# Update thread index
|
||||
if spec.scope.thread_id not in self._thread_index:
|
||||
self._thread_index[spec.scope.thread_id] = set()
|
||||
self._thread_index[spec.scope.thread_id].add(run_id)
|
||||
|
||||
return record
|
||||
|
||||
def get(self, run_id: str) -> RunRecord | None:
|
||||
"""Get RunRecord by run_id."""
|
||||
return self._records.get(run_id)
|
||||
|
||||
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
|
||||
"""List all RunRecords for a thread."""
|
||||
async with self._lock:
|
||||
run_ids = self._thread_index.get(thread_id, set())
|
||||
return [self._records[rid] for rid in run_ids if rid in self._records]
|
||||
|
||||
async def set_status(
|
||||
self,
|
||||
run_id: str,
|
||||
status: RunStatus,
|
||||
*,
|
||||
error: str | None = None,
|
||||
started_at: str | None = None,
|
||||
ended_at: str | None = None,
|
||||
) -> None:
|
||||
"""Update run status and optional fields."""
|
||||
async with self._lock:
|
||||
record = self._records.get(run_id)
|
||||
if record is None:
|
||||
return
|
||||
|
||||
record.status = status
|
||||
record.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
if error is not None:
|
||||
record.error = error
|
||||
if started_at is not None:
|
||||
record.started_at = started_at
|
||||
if ended_at is not None:
|
||||
record.ended_at = ended_at
|
||||
|
||||
async def has_inflight(self, thread_id: str) -> bool:
|
||||
"""Check if thread has any inflight runs."""
|
||||
async with self._lock:
|
||||
run_ids = self._thread_index.get(thread_id, set())
|
||||
for rid in run_ids:
|
||||
record = self._records.get(rid)
|
||||
if record and record.status in INFLIGHT_STATUSES:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def interrupt_inflight(self, thread_id: str) -> list[str]:
|
||||
"""
|
||||
Mark all inflight runs for a thread as interrupted.
|
||||
|
||||
Returns list of interrupted run_ids.
|
||||
"""
|
||||
interrupted: list[str] = []
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
async with self._lock:
|
||||
run_ids = self._thread_index.get(thread_id, set())
|
||||
for rid in run_ids:
|
||||
record = self._records.get(rid)
|
||||
if record and record.status in INFLIGHT_STATUSES:
|
||||
record.status = "interrupted"
|
||||
record.updated_at = now
|
||||
record.ended_at = now
|
||||
interrupted.append(rid)
|
||||
|
||||
return interrupted
|
||||
|
||||
async def update_metadata(self, run_id: str, metadata: dict[str, Any]) -> None:
|
||||
"""Update run metadata."""
|
||||
async with self._lock:
|
||||
record = self._records.get(run_id)
|
||||
if record is not None:
|
||||
record.metadata.update(metadata)
|
||||
record.updated_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
async def delete(self, run_id: str) -> bool:
|
||||
"""Delete a run record. Returns True if deleted."""
|
||||
async with self._lock:
|
||||
record = self._records.pop(run_id, None)
|
||||
if record is None:
|
||||
return False
|
||||
|
||||
# Update thread index
|
||||
thread_runs = self._thread_index.get(record.thread_id)
|
||||
if thread_runs:
|
||||
thread_runs.discard(run_id)
|
||||
|
||||
return True
|
||||
|
||||
def count(self) -> int:
|
||||
"""Return total number of records."""
|
||||
return len(self._records)
|
||||
|
||||
def count_by_status(self, status: RunStatus) -> int:
|
||||
"""Return count of records with given status."""
|
||||
return sum(1 for r in self._records.values() if r.status == status)
|
||||
|
||||
|
||||
# Compatibility alias during the refactor.
|
||||
RuntimeRunRegistry = RunRegistry
|
||||
@@ -0,0 +1,76 @@
|
||||
"""Internal run stream adapter over StreamBridge."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.stream_bridge import JSONValue, StreamBridge, StreamEvent
|
||||
|
||||
from deerflow.runtime.stream_bridge import StreamStatus
|
||||
|
||||
|
||||
class RunStreamService:
|
||||
"""Thin runs-domain adapter over the harness stream bridge contract."""
|
||||
|
||||
def __init__(self, bridge: "StreamBridge") -> None:
|
||||
self._bridge = bridge
|
||||
|
||||
async def publish_event(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
event: str,
|
||||
data: "JSONValue",
|
||||
) -> str:
|
||||
"""Publish a replayable run event."""
|
||||
return await self._bridge.publish(run_id, event, data)
|
||||
|
||||
async def publish_end(self, run_id: str) -> str:
|
||||
"""Publish a successful terminal signal."""
|
||||
return await self._bridge.publish_terminal(run_id, StreamStatus.ENDED)
|
||||
|
||||
async def publish_cancelled(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
data: "JSONValue" = None,
|
||||
) -> str:
|
||||
"""Publish a cancelled terminal signal."""
|
||||
return await self._bridge.publish_terminal(
|
||||
run_id,
|
||||
StreamStatus.CANCELLED,
|
||||
data,
|
||||
)
|
||||
|
||||
async def publish_error(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
data: "JSONValue",
|
||||
) -> str:
|
||||
"""Publish a failed terminal signal."""
|
||||
return await self._bridge.publish_terminal(
|
||||
run_id,
|
||||
StreamStatus.ERRORED,
|
||||
data,
|
||||
)
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
"""Subscribe to a run stream with resume support."""
|
||||
return self._bridge.subscribe(
|
||||
run_id,
|
||||
last_event_id=last_event_id,
|
||||
heartbeat_interval=heartbeat_interval,
|
||||
)
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
"""Release per-run bridge resources after completion."""
|
||||
await self._bridge.cleanup(run_id, delay=delay)
|
||||
@@ -0,0 +1,95 @@
|
||||
"""Internal run wait helpers based on stream events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.stream_bridge import StreamEvent
|
||||
|
||||
from .streams import RunStreamService
|
||||
|
||||
|
||||
class WaitTimeoutError(TimeoutError):
|
||||
"""Raised when wait times out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WaitErrorResult:
|
||||
"""Represents an error result from wait."""
|
||||
|
||||
def __init__(self, error: str, details: dict[str, Any] | None = None) -> None:
|
||||
self.error = error
|
||||
self.details = details or {}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"error": self.error, **self.details}
|
||||
|
||||
|
||||
class RunWaitService:
|
||||
"""
|
||||
Wait service for runs domain.
|
||||
|
||||
Based on RunStreamService.subscribe(), implements wait semantics.
|
||||
|
||||
Phase 1 behavior:
|
||||
- Records last 'values' event
|
||||
- On 'error', returns unified error structure
|
||||
- On 'end' only, returns last values
|
||||
"""
|
||||
|
||||
TERMINAL_EVENTS = frozenset({"end", "error", "cancel"})
|
||||
|
||||
def __init__(self, stream_service: RunStreamService) -> None:
|
||||
self._stream_service = stream_service
|
||||
|
||||
async def wait_for_terminal(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
) -> StreamEvent | None:
|
||||
"""Block until the next terminal event for a run is observed."""
|
||||
async for event in self._stream_service.subscribe(
|
||||
run_id,
|
||||
last_event_id=last_event_id,
|
||||
):
|
||||
if event.event in self.TERMINAL_EVENTS:
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
async def wait_for_values_or_error(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
) -> dict[str, Any] | WaitErrorResult | None:
|
||||
"""
|
||||
Wait for run to complete and return final values or error.
|
||||
|
||||
Returns:
|
||||
- dict: Final values if successful
|
||||
- WaitErrorResult: If run failed
|
||||
- None: If no values were produced
|
||||
"""
|
||||
last_values: dict[str, Any] | None = None
|
||||
|
||||
async for event in self._stream_service.subscribe(
|
||||
run_id,
|
||||
last_event_id=last_event_id,
|
||||
):
|
||||
if event.event == "values":
|
||||
last_values = event.data
|
||||
|
||||
elif event.event == "error":
|
||||
return WaitErrorResult(
|
||||
error=str(event.data) if event.data else "Unknown error",
|
||||
details={"run_id": run_id},
|
||||
)
|
||||
|
||||
elif event.event in self.TERMINAL_EVENTS:
|
||||
# Stream ended, return last values
|
||||
break
|
||||
|
||||
return last_values
|
||||
Reference in New Issue
Block a user