mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
refactor(runtime): restructure runs module with new execution architecture
Major refactoring of deerflow/runtime/: - runs/callbacks/ - new callback system (builder, events, title, tokens) - runs/internal/ - execution internals (executor, supervisor, stream_logic, registry) - runs/internal/execution/ - execution artifacts and events handling - runs/facade.py - high-level run facade - runs/observer.py - run observation protocol - runs/types.py - type definitions - runs/store/ - simplified store interfaces (create, delete, query, event) Refactor stream_bridge/: - Replace old providers with contract.py and exceptions.py - Remove async_provider.py, base.py, memory.py Add documentation: - README.md and README_zh.md for runtime module Remove deprecated: - manager.py moved to internal/ - worker.py, schemas.py - user_context.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Internal execution components for runs domain."""
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Execution preparation helpers for a single run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunBuildArtifacts:
|
||||
"""Assembled agent runtime pieces for a single run."""
|
||||
|
||||
agent: Any
|
||||
runnable_config: dict[str, Any]
|
||||
reference_store: Any | None = None
|
||||
|
||||
|
||||
def build_run_artifacts(
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
checkpointer: Any | None,
|
||||
store: Any | None,
|
||||
agent_factory: Any,
|
||||
config: dict[str, Any],
|
||||
bridge: StreamBridge,
|
||||
interrupt_before: list[str] | None = None,
|
||||
interrupt_after: list[str] | None = None,
|
||||
callbacks: list[BaseCallbackHandler] | None = None,
|
||||
) -> RunBuildArtifacts:
|
||||
"""Assemble all components needed for agent execution."""
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
if "context" in config and isinstance(config["context"], dict):
|
||||
config["context"].setdefault("thread_id", thread_id)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
config_callbacks = config.setdefault("callbacks", [])
|
||||
if callbacks:
|
||||
config_callbacks.extend(callbacks)
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
if checkpointer is not None:
|
||||
agent.checkpointer = checkpointer
|
||||
if store is not None:
|
||||
agent.store = store
|
||||
|
||||
if interrupt_before:
|
||||
agent.interrupt_before_nodes = interrupt_before
|
||||
if interrupt_after:
|
||||
agent.interrupt_after_nodes = interrupt_after
|
||||
|
||||
return RunBuildArtifacts(
|
||||
agent=agent,
|
||||
runnable_config=dict(runnable_config),
|
||||
reference_store=store,
|
||||
)
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Lifecycle event helpers for run execution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from ...observer import LifecycleEventType, RunLifecycleEvent, RunObserver
|
||||
|
||||
|
||||
class RunEventEmitter:
|
||||
"""Build and dispatch lifecycle events for a single run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
observer: RunObserver,
|
||||
) -> None:
|
||||
self._run_id = run_id
|
||||
self._thread_id = thread_id
|
||||
self._observer = observer
|
||||
self._sequence = 0
|
||||
|
||||
@property
|
||||
def sequence(self) -> int:
|
||||
return self._sequence
|
||||
|
||||
async def emit(
|
||||
self,
|
||||
event_type: LifecycleEventType,
|
||||
payload: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
self._sequence += 1
|
||||
event = RunLifecycleEvent(
|
||||
event_id=f"{self._run_id}:{event_type.value}:{self._sequence}",
|
||||
event_type=event_type,
|
||||
run_id=self._run_id,
|
||||
thread_id=self._thread_id,
|
||||
sequence=self._sequence,
|
||||
occurred_at=datetime.now(UTC),
|
||||
payload=payload or {},
|
||||
)
|
||||
await self._observer.on_event(event)
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Single-run execution orchestrator and execution-local helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from deerflow.runtime.serialization import serialize
|
||||
from deerflow.runtime.stream_bridge import StreamBridge, StreamStatus
|
||||
|
||||
from ...callbacks.builder import RunCallbackArtifacts, build_run_callbacks
|
||||
from ...observer import LifecycleEventType, RunObserver, RunResult
|
||||
from ...store import RunEventStore
|
||||
from ...types import RunStatus
|
||||
from .artifacts import build_run_artifacts
|
||||
from .events import RunEventEmitter
|
||||
from .stream_logic import external_stream_event_name, normalize_stream_modes, should_filter_event, unpack_stream_item
|
||||
from .supervisor import RunHandle
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _RunExecution:
|
||||
"""Encapsulate the lifecycle of a single run."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bridge: StreamBridge,
|
||||
run_manager: Any,
|
||||
record: Any,
|
||||
checkpointer: Any | None = None,
|
||||
store: Any | None = None,
|
||||
event_store: RunEventStore | None = None,
|
||||
ctx: Any | None = None,
|
||||
agent_factory: Any,
|
||||
graph_input: dict,
|
||||
config: dict,
|
||||
observer: RunObserver,
|
||||
stream_modes: list[str] | None,
|
||||
stream_subgraphs: bool,
|
||||
interrupt_before: list[str] | Literal["*"] | None,
|
||||
interrupt_after: list[str] | Literal["*"] | None,
|
||||
handle: RunHandle | None = None,
|
||||
) -> None:
|
||||
if ctx is not None:
|
||||
checkpointer = getattr(ctx, "checkpointer", checkpointer)
|
||||
store = getattr(ctx, "store", store)
|
||||
|
||||
self.bridge = bridge
|
||||
self.run_manager = run_manager
|
||||
self.record = record
|
||||
self.checkpointer = checkpointer
|
||||
self.store = store
|
||||
self.event_store = event_store
|
||||
self.agent_factory = agent_factory
|
||||
self.graph_input = graph_input
|
||||
self.config = config
|
||||
self.observer = observer
|
||||
self.stream_modes = stream_modes
|
||||
self.stream_subgraphs = stream_subgraphs
|
||||
self.interrupt_before = interrupt_before
|
||||
self.interrupt_after = interrupt_after
|
||||
self.handle = handle
|
||||
|
||||
self.run_id = record.run_id
|
||||
self.thread_id = record.thread_id
|
||||
self._pre_run_checkpoint_id: str | None = None
|
||||
self._emitter = RunEventEmitter(
|
||||
run_id=self.run_id,
|
||||
thread_id=self.thread_id,
|
||||
observer=observer,
|
||||
)
|
||||
self.result = RunResult(
|
||||
run_id=self.run_id,
|
||||
thread_id=self.thread_id,
|
||||
status=RunStatus.pending,
|
||||
)
|
||||
self._agent: Any = None
|
||||
self._runnable_config: dict[str, Any] = {}
|
||||
self._lg_modes: list[str] = []
|
||||
self._callback_artifacts: RunCallbackArtifacts | None = None
|
||||
|
||||
@property
|
||||
def _event_sequence(self) -> int:
|
||||
return self._emitter.sequence
|
||||
|
||||
async def _emit(
|
||||
self,
|
||||
event_type: LifecycleEventType,
|
||||
payload: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
await self._emitter.emit(event_type, payload)
|
||||
|
||||
async def _start(self) -> None:
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.running)
|
||||
await self._emit(LifecycleEventType.RUN_STARTED, {})
|
||||
|
||||
human_msg = self._extract_human_message()
|
||||
if human_msg is not None:
|
||||
await self._emit(
|
||||
LifecycleEventType.HUMAN_MESSAGE,
|
||||
{"message": human_msg.model_dump()},
|
||||
)
|
||||
|
||||
await self._capture_pre_run_checkpoint()
|
||||
await self.bridge.publish(
|
||||
self.run_id,
|
||||
"metadata",
|
||||
{"run_id": self.run_id, "thread_id": self.thread_id},
|
||||
)
|
||||
|
||||
def _extract_human_message(self) -> Any:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
messages = self.graph_input.get("messages")
|
||||
if not messages:
|
||||
return None
|
||||
last = messages[-1] if isinstance(messages, list) else messages
|
||||
if isinstance(last, HumanMessage):
|
||||
return last
|
||||
if isinstance(last, str):
|
||||
return HumanMessage(content=last) if last else None
|
||||
if hasattr(last, "content"):
|
||||
return HumanMessage(content=last.content)
|
||||
if isinstance(last, dict):
|
||||
content = last.get("content", "")
|
||||
return HumanMessage(content=content) if content else None
|
||||
return None
|
||||
|
||||
async def _capture_pre_run_checkpoint(self) -> None:
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await self.checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
self._pre_run_checkpoint_id = (
|
||||
getattr(ckpt_tuple, "config", {})
|
||||
.get("configurable", {})
|
||||
.get("checkpoint_id")
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Could not get pre-run checkpoint_id for run %s", self.run_id)
|
||||
|
||||
async def _prepare(self) -> None:
|
||||
config = dict(self.config)
|
||||
existing_callbacks = config.pop("callbacks", [])
|
||||
if existing_callbacks is None:
|
||||
existing_callbacks = []
|
||||
elif not isinstance(existing_callbacks, list):
|
||||
existing_callbacks = [existing_callbacks]
|
||||
|
||||
self._callback_artifacts = build_run_callbacks(
|
||||
record=self.record,
|
||||
graph_input=self.graph_input,
|
||||
event_store=self.event_store,
|
||||
existing_callbacks=existing_callbacks,
|
||||
)
|
||||
|
||||
artifacts = build_run_artifacts(
|
||||
thread_id=self.thread_id,
|
||||
run_id=self.run_id,
|
||||
checkpointer=self.checkpointer,
|
||||
store=self.store,
|
||||
agent_factory=self.agent_factory,
|
||||
config=config,
|
||||
bridge=self.bridge,
|
||||
interrupt_before=self.interrupt_before,
|
||||
interrupt_after=self.interrupt_after,
|
||||
callbacks=self._callback_artifacts.callbacks,
|
||||
)
|
||||
|
||||
self._agent = artifacts.agent
|
||||
self._runnable_config = artifacts.runnable_config
|
||||
self._lg_modes = normalize_stream_modes(self.stream_modes)
|
||||
logger.info(
|
||||
"Run %s: streaming with modes %s (requested: %s)",
|
||||
self.run_id,
|
||||
self._lg_modes,
|
||||
self.stream_modes,
|
||||
)
|
||||
|
||||
async def _finish_success(self) -> None:
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.success)
|
||||
await self.bridge.publish_terminal(self.run_id, StreamStatus.ENDED)
|
||||
self.result.status = RunStatus.success
|
||||
completion_data = self._completion_data()
|
||||
title = self._callback_title() or await self._extract_title_from_checkpoint()
|
||||
self.result.title = title
|
||||
self.result.completion_data = completion_data
|
||||
await self._emit(
|
||||
LifecycleEventType.RUN_COMPLETED,
|
||||
{
|
||||
"title": title,
|
||||
"completion_data": completion_data,
|
||||
},
|
||||
)
|
||||
|
||||
async def _finish_aborted(self, cancel_mode: str) -> None:
|
||||
payload = {
|
||||
"cancel_mode": cancel_mode,
|
||||
"pre_run_checkpoint_id": self._pre_run_checkpoint_id,
|
||||
"completion_data": self._completion_data(),
|
||||
}
|
||||
|
||||
if cancel_mode == "rollback":
|
||||
await self.run_manager.set_status(
|
||||
self.run_id,
|
||||
RunStatus.error,
|
||||
error="Rolled back by user",
|
||||
)
|
||||
await self.bridge.publish_terminal(
|
||||
self.run_id,
|
||||
StreamStatus.CANCELLED,
|
||||
{"cancel_mode": "rollback", "message": "Rolled back by user"},
|
||||
)
|
||||
self.result.status = RunStatus.error
|
||||
self.result.error = "Rolled back by user"
|
||||
logger.info("Run %s rolled back", self.run_id)
|
||||
else:
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.interrupted)
|
||||
await self.bridge.publish_terminal(
|
||||
self.run_id,
|
||||
StreamStatus.CANCELLED,
|
||||
{"cancel_mode": cancel_mode},
|
||||
)
|
||||
self.result.status = RunStatus.interrupted
|
||||
logger.info("Run %s cancelled (mode=%s)", self.run_id, cancel_mode)
|
||||
|
||||
await self._emit(LifecycleEventType.RUN_CANCELLED, payload)
|
||||
|
||||
async def _finish_failed(self, exc: Exception) -> None:
|
||||
error_msg = str(exc)
|
||||
logger.exception("Run %s failed: %s", self.run_id, error_msg)
|
||||
|
||||
await self.run_manager.set_status(self.run_id, RunStatus.error, error=error_msg)
|
||||
await self.bridge.publish_terminal(
|
||||
self.run_id,
|
||||
StreamStatus.ERRORED,
|
||||
{"message": error_msg, "name": type(exc).__name__},
|
||||
)
|
||||
self.result.status = RunStatus.error
|
||||
self.result.error = error_msg
|
||||
|
||||
await self._emit(
|
||||
LifecycleEventType.RUN_FAILED,
|
||||
{
|
||||
"error": error_msg,
|
||||
"error_type": type(exc).__name__,
|
||||
"completion_data": self._completion_data(),
|
||||
},
|
||||
)
|
||||
|
||||
def _completion_data(self) -> dict[str, object]:
|
||||
if self._callback_artifacts is None:
|
||||
return {}
|
||||
return self._callback_artifacts.completion_data().to_dict()
|
||||
|
||||
def _callback_title(self) -> str | None:
|
||||
if self._callback_artifacts is None:
|
||||
return None
|
||||
return self._callback_artifacts.title()
|
||||
|
||||
async def _extract_title_from_checkpoint(self) -> str | None:
|
||||
if self.checkpointer is None:
|
||||
return None
|
||||
try:
|
||||
ckpt_config = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await self.checkpointer.aget_tuple(ckpt_config)
|
||||
if ckpt_tuple is not None:
|
||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||
return ckpt.get("channel_values", {}).get("title")
|
||||
except Exception:
|
||||
logger.debug("Failed to extract title from checkpoint for thread %s", self.thread_id)
|
||||
return None
|
||||
|
||||
def _map_run_status_to_thread_status(self, status: RunStatus) -> str:
|
||||
if status == RunStatus.success:
|
||||
return "idle"
|
||||
if status == RunStatus.interrupted:
|
||||
return "interrupted"
|
||||
if status in (RunStatus.error, RunStatus.timeout):
|
||||
return "error"
|
||||
return "running"
|
||||
|
||||
def _abort_requested(self) -> bool:
|
||||
if self.handle is not None:
|
||||
return self.handle.cancel_event.is_set()
|
||||
return self.record.abort_event.is_set()
|
||||
|
||||
def _abort_action(self) -> str:
|
||||
if self.handle is not None:
|
||||
return self.handle.cancel_action
|
||||
return self.record.abort_action
|
||||
|
||||
async def _stream(self) -> None:
|
||||
runnable_config = RunnableConfig(**self._runnable_config)
|
||||
|
||||
if len(self._lg_modes) == 1 and not self.stream_subgraphs:
|
||||
single_mode = self._lg_modes[0]
|
||||
async for chunk in self._agent.astream(
|
||||
self.graph_input,
|
||||
config=runnable_config,
|
||||
stream_mode=single_mode,
|
||||
):
|
||||
if self._abort_requested():
|
||||
logger.info("Run %s abort requested - stopping", self.run_id)
|
||||
break
|
||||
if should_filter_event(single_mode, chunk):
|
||||
continue
|
||||
await self.bridge.publish(
|
||||
self.run_id,
|
||||
external_stream_event_name(single_mode),
|
||||
serialize(chunk, mode=single_mode),
|
||||
)
|
||||
return
|
||||
|
||||
async for item in self._agent.astream(
|
||||
self.graph_input,
|
||||
config=runnable_config,
|
||||
stream_mode=self._lg_modes,
|
||||
subgraphs=self.stream_subgraphs,
|
||||
):
|
||||
if self._abort_requested():
|
||||
logger.info("Run %s abort requested - stopping", self.run_id)
|
||||
break
|
||||
|
||||
mode, chunk = unpack_stream_item(item, self._lg_modes, stream_subgraphs=self.stream_subgraphs)
|
||||
if mode is None:
|
||||
continue
|
||||
if should_filter_event(mode, chunk):
|
||||
continue
|
||||
await self.bridge.publish(
|
||||
self.run_id,
|
||||
external_stream_event_name(mode),
|
||||
serialize(chunk, mode=mode),
|
||||
)
|
||||
|
||||
async def _finish_after_stream(self) -> None:
|
||||
if self._abort_requested():
|
||||
action = self._abort_action()
|
||||
cancel_mode = "rollback" if action == "rollback" else "interrupt"
|
||||
await self._finish_aborted(cancel_mode)
|
||||
return
|
||||
|
||||
await self._finish_success()
|
||||
|
||||
async def _emit_final_thread_status(self) -> None:
|
||||
final_thread_status = self._map_run_status_to_thread_status(self.result.status)
|
||||
await self._emit(
|
||||
LifecycleEventType.THREAD_STATUS_UPDATED,
|
||||
{"status": final_thread_status},
|
||||
)
|
||||
|
||||
async def run(self) -> RunResult:
|
||||
try:
|
||||
await self._start()
|
||||
await self._prepare()
|
||||
await self._stream()
|
||||
await self._finish_after_stream()
|
||||
except asyncio.CancelledError:
|
||||
await self._finish_aborted("task_cancelled")
|
||||
except Exception as exc:
|
||||
await self._finish_failed(exc)
|
||||
finally:
|
||||
await self._emit_final_thread_status()
|
||||
if self._callback_artifacts is not None:
|
||||
await self._callback_artifacts.flush()
|
||||
await self.bridge.cleanup(self.run_id)
|
||||
|
||||
return self.result
|
||||
|
||||
|
||||
__all__ = ["_RunExecution"]
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Execution-local stream processing helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StreamItem:
|
||||
"""Normalized stream item from LangGraph."""
|
||||
|
||||
mode: str
|
||||
chunk: Any
|
||||
|
||||
|
||||
_FILTERED_NODES = frozenset({"__start__", "__end__"})
|
||||
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
|
||||
|
||||
|
||||
def normalize_stream_modes(requested_modes: list[str] | None) -> list[str]:
|
||||
"""Normalize requested stream modes to valid LangGraph modes."""
|
||||
input_modes: list[str] = list(requested_modes or ["values"])
|
||||
|
||||
lg_modes: list[str] = []
|
||||
for mode in input_modes:
|
||||
if mode == "messages-tuple":
|
||||
lg_modes.append("messages")
|
||||
elif mode == "events":
|
||||
logger.info("'events' stream_mode not supported (requires astream_events). Skipping.")
|
||||
continue
|
||||
elif mode in _VALID_LG_MODES:
|
||||
lg_modes.append(mode)
|
||||
|
||||
if not lg_modes:
|
||||
lg_modes = ["values"]
|
||||
|
||||
seen: set[str] = set()
|
||||
deduped: list[str] = []
|
||||
for mode in lg_modes:
|
||||
if mode not in seen:
|
||||
seen.add(mode)
|
||||
deduped.append(mode)
|
||||
|
||||
return deduped
|
||||
|
||||
|
||||
def unpack_stream_item(
|
||||
item: Any,
|
||||
lg_modes: list[str],
|
||||
*,
|
||||
stream_subgraphs: bool,
|
||||
) -> tuple[str | None, Any]:
|
||||
"""Unpack a multi-mode or subgraph stream item into ``(mode, chunk)``."""
|
||||
if stream_subgraphs:
|
||||
if isinstance(item, tuple) and len(item) == 3:
|
||||
_namespace, mode, chunk = item
|
||||
return str(mode), chunk
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
return str(mode), chunk
|
||||
return None, None
|
||||
|
||||
if isinstance(item, tuple) and len(item) == 2:
|
||||
mode, chunk = item
|
||||
return str(mode), chunk
|
||||
|
||||
return lg_modes[0] if lg_modes else None, item
|
||||
|
||||
|
||||
def should_filter_event(mode: str, chunk: Any) -> bool:
|
||||
"""Determine whether a stream event should be filtered before publish."""
|
||||
if mode == "updates" and isinstance(chunk, dict):
|
||||
node_names = set(chunk.keys())
|
||||
if node_names & _FILTERED_NODES:
|
||||
return True
|
||||
|
||||
if mode == "messages" and isinstance(chunk, tuple) and len(chunk) == 2:
|
||||
_, metadata = chunk
|
||||
if isinstance(metadata, dict):
|
||||
node = metadata.get("langgraph_node", "")
|
||||
if node in _FILTERED_NODES:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def external_stream_event_name(mode: str) -> str:
|
||||
"""Map LangGraph internal modes to the external SSE event contract."""
|
||||
return mode
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Active execution handle management for runs domain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ...types import CancelAction
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunHandle:
|
||||
"""In-process control handle for an active run."""
|
||||
|
||||
run_id: str
|
||||
task: asyncio.Task[Any] | None = None
|
||||
cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
cancel_action: CancelAction = "interrupt"
|
||||
|
||||
|
||||
class RunSupervisor:
|
||||
"""Own and control active run handles within the current process."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._handles: dict[str, RunHandle] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def launch(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
runner: Callable[[RunHandle], Awaitable[Any]],
|
||||
) -> RunHandle:
|
||||
"""Create a handle and start a background task for it."""
|
||||
handle = RunHandle(run_id=run_id)
|
||||
|
||||
async with self._lock:
|
||||
if run_id in self._handles:
|
||||
raise RuntimeError(f"Run {run_id} is already active")
|
||||
self._handles[run_id] = handle
|
||||
|
||||
task = asyncio.create_task(runner(handle))
|
||||
handle.task = task
|
||||
task.add_done_callback(lambda _: asyncio.create_task(self.cleanup(run_id)))
|
||||
return handle
|
||||
|
||||
async def cancel(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
action: CancelAction = "interrupt",
|
||||
) -> bool:
|
||||
"""Signal cancellation for an active handle."""
|
||||
async with self._lock:
|
||||
handle = self._handles.get(run_id)
|
||||
if handle is None:
|
||||
return False
|
||||
|
||||
handle.cancel_action = action
|
||||
handle.cancel_event.set()
|
||||
if handle.task is not None and not handle.task.done():
|
||||
handle.task.cancel()
|
||||
|
||||
return True
|
||||
|
||||
def get_handle(self, run_id: str) -> RunHandle | None:
|
||||
"""Return the active handle for a run, if any."""
|
||||
return self._handles.get(run_id)
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
"""Remove a handle after optional delay."""
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
async with self._lock:
|
||||
self._handles.pop(run_id, None)
|
||||
Reference in New Issue
Block a user