refactor(runtime): restructure runs module with new execution architecture

Major refactoring of deerflow/runtime/:
- runs/callbacks/ - new callback system (builder, events, title, tokens)
- runs/internal/ - execution internals (executor, supervisor, stream_logic, registry)
- runs/internal/execution/ - execution artifacts and events handling
- runs/facade.py - high-level run facade
- runs/observer.py - run observation protocol
- runs/types.py - type definitions
- runs/store/ - simplified store interfaces (create, delete, query, event)

Refactor stream_bridge/:
- Replace old providers with contract.py and exceptions.py
- Remove async_provider.py, base.py, memory.py

Add documentation:
- README.md and README_zh.md for runtime module

Remove deprecated:
- manager.py moved to internal/
- worker.py, schemas.py
- user_context.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-22 11:28:01 +08:00
parent 39a575617b
commit 9d0a42c1fb
43 changed files with 3928 additions and 1192 deletions
@@ -0,0 +1,4 @@
"""Internal runs implementation modules.
These modules are implementation details behind the public runs surface.
"""
@@ -0,0 +1 @@
"""Internal execution components for runs domain."""
@@ -0,0 +1,64 @@
"""Execution preparation helpers for a single run."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.runnables import RunnableConfig
from langgraph.runtime import Runtime
from deerflow.runtime.stream_bridge import StreamBridge
@dataclass
class RunBuildArtifacts:
"""Assembled agent runtime pieces for a single run."""
agent: Any
runnable_config: dict[str, Any]
reference_store: Any | None = None
def build_run_artifacts(
*,
thread_id: str,
run_id: str,
checkpointer: Any | None,
store: Any | None,
agent_factory: Any,
config: dict[str, Any],
bridge: StreamBridge,
interrupt_before: list[str] | None = None,
interrupt_after: list[str] | None = None,
callbacks: list[BaseCallbackHandler] | None = None,
) -> RunBuildArtifacts:
"""Assemble all components needed for agent execution."""
runtime = Runtime(context={"thread_id": thread_id}, store=store)
if "context" in config and isinstance(config["context"], dict):
config["context"].setdefault("thread_id", thread_id)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
config_callbacks = config.setdefault("callbacks", [])
if callbacks:
config_callbacks.extend(callbacks)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
if checkpointer is not None:
agent.checkpointer = checkpointer
if store is not None:
agent.store = store
if interrupt_before:
agent.interrupt_before_nodes = interrupt_before
if interrupt_after:
agent.interrupt_after_nodes = interrupt_after
return RunBuildArtifacts(
agent=agent,
runnable_config=dict(runnable_config),
reference_store=store,
)
@@ -0,0 +1,45 @@
"""Lifecycle event helpers for run execution."""
from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
from ...observer import LifecycleEventType, RunLifecycleEvent, RunObserver
class RunEventEmitter:
"""Build and dispatch lifecycle events for a single run."""
def __init__(
self,
*,
run_id: str,
thread_id: str,
observer: RunObserver,
) -> None:
self._run_id = run_id
self._thread_id = thread_id
self._observer = observer
self._sequence = 0
@property
def sequence(self) -> int:
return self._sequence
async def emit(
self,
event_type: LifecycleEventType,
payload: dict[str, Any] | None = None,
) -> None:
self._sequence += 1
event = RunLifecycleEvent(
event_id=f"{self._run_id}:{event_type.value}:{self._sequence}",
event_type=event_type,
run_id=self._run_id,
thread_id=self._thread_id,
sequence=self._sequence,
occurred_at=datetime.now(UTC),
payload=payload or {},
)
await self._observer.on_event(event)
@@ -0,0 +1,376 @@
"""Single-run execution orchestrator and execution-local helpers."""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Literal
from langchain_core.runnables import RunnableConfig
from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge, StreamStatus
from ...callbacks.builder import RunCallbackArtifacts, build_run_callbacks
from ...observer import LifecycleEventType, RunObserver, RunResult
from ...store import RunEventStore
from ...types import RunStatus
from .artifacts import build_run_artifacts
from .events import RunEventEmitter
from .stream_logic import external_stream_event_name, normalize_stream_modes, should_filter_event, unpack_stream_item
from .supervisor import RunHandle
logger = logging.getLogger(__name__)
class _RunExecution:
"""Encapsulate the lifecycle of a single run."""
def __init__(
self,
*,
bridge: StreamBridge,
run_manager: Any,
record: Any,
checkpointer: Any | None = None,
store: Any | None = None,
event_store: RunEventStore | None = None,
ctx: Any | None = None,
agent_factory: Any,
graph_input: dict,
config: dict,
observer: RunObserver,
stream_modes: list[str] | None,
stream_subgraphs: bool,
interrupt_before: list[str] | Literal["*"] | None,
interrupt_after: list[str] | Literal["*"] | None,
handle: RunHandle | None = None,
) -> None:
if ctx is not None:
checkpointer = getattr(ctx, "checkpointer", checkpointer)
store = getattr(ctx, "store", store)
self.bridge = bridge
self.run_manager = run_manager
self.record = record
self.checkpointer = checkpointer
self.store = store
self.event_store = event_store
self.agent_factory = agent_factory
self.graph_input = graph_input
self.config = config
self.observer = observer
self.stream_modes = stream_modes
self.stream_subgraphs = stream_subgraphs
self.interrupt_before = interrupt_before
self.interrupt_after = interrupt_after
self.handle = handle
self.run_id = record.run_id
self.thread_id = record.thread_id
self._pre_run_checkpoint_id: str | None = None
self._emitter = RunEventEmitter(
run_id=self.run_id,
thread_id=self.thread_id,
observer=observer,
)
self.result = RunResult(
run_id=self.run_id,
thread_id=self.thread_id,
status=RunStatus.pending,
)
self._agent: Any = None
self._runnable_config: dict[str, Any] = {}
self._lg_modes: list[str] = []
self._callback_artifacts: RunCallbackArtifacts | None = None
@property
def _event_sequence(self) -> int:
return self._emitter.sequence
async def _emit(
self,
event_type: LifecycleEventType,
payload: dict[str, Any] | None = None,
) -> None:
await self._emitter.emit(event_type, payload)
async def _start(self) -> None:
await self.run_manager.set_status(self.run_id, RunStatus.running)
await self._emit(LifecycleEventType.RUN_STARTED, {})
human_msg = self._extract_human_message()
if human_msg is not None:
await self._emit(
LifecycleEventType.HUMAN_MESSAGE,
{"message": human_msg.model_dump()},
)
await self._capture_pre_run_checkpoint()
await self.bridge.publish(
self.run_id,
"metadata",
{"run_id": self.run_id, "thread_id": self.thread_id},
)
def _extract_human_message(self) -> Any:
from langchain_core.messages import HumanMessage
messages = self.graph_input.get("messages")
if not messages:
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
return HumanMessage(content=last.content)
if isinstance(last, dict):
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
async def _capture_pre_run_checkpoint(self) -> None:
try:
config_for_check = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await self.checkpointer.aget_tuple(config_for_check)
if ckpt_tuple is not None:
self._pre_run_checkpoint_id = (
getattr(ckpt_tuple, "config", {})
.get("configurable", {})
.get("checkpoint_id")
)
except Exception:
logger.debug("Could not get pre-run checkpoint_id for run %s", self.run_id)
async def _prepare(self) -> None:
config = dict(self.config)
existing_callbacks = config.pop("callbacks", [])
if existing_callbacks is None:
existing_callbacks = []
elif not isinstance(existing_callbacks, list):
existing_callbacks = [existing_callbacks]
self._callback_artifacts = build_run_callbacks(
record=self.record,
graph_input=self.graph_input,
event_store=self.event_store,
existing_callbacks=existing_callbacks,
)
artifacts = build_run_artifacts(
thread_id=self.thread_id,
run_id=self.run_id,
checkpointer=self.checkpointer,
store=self.store,
agent_factory=self.agent_factory,
config=config,
bridge=self.bridge,
interrupt_before=self.interrupt_before,
interrupt_after=self.interrupt_after,
callbacks=self._callback_artifacts.callbacks,
)
self._agent = artifacts.agent
self._runnable_config = artifacts.runnable_config
self._lg_modes = normalize_stream_modes(self.stream_modes)
logger.info(
"Run %s: streaming with modes %s (requested: %s)",
self.run_id,
self._lg_modes,
self.stream_modes,
)
async def _finish_success(self) -> None:
await self.run_manager.set_status(self.run_id, RunStatus.success)
await self.bridge.publish_terminal(self.run_id, StreamStatus.ENDED)
self.result.status = RunStatus.success
completion_data = self._completion_data()
title = self._callback_title() or await self._extract_title_from_checkpoint()
self.result.title = title
self.result.completion_data = completion_data
await self._emit(
LifecycleEventType.RUN_COMPLETED,
{
"title": title,
"completion_data": completion_data,
},
)
async def _finish_aborted(self, cancel_mode: str) -> None:
payload = {
"cancel_mode": cancel_mode,
"pre_run_checkpoint_id": self._pre_run_checkpoint_id,
"completion_data": self._completion_data(),
}
if cancel_mode == "rollback":
await self.run_manager.set_status(
self.run_id,
RunStatus.error,
error="Rolled back by user",
)
await self.bridge.publish_terminal(
self.run_id,
StreamStatus.CANCELLED,
{"cancel_mode": "rollback", "message": "Rolled back by user"},
)
self.result.status = RunStatus.error
self.result.error = "Rolled back by user"
logger.info("Run %s rolled back", self.run_id)
else:
await self.run_manager.set_status(self.run_id, RunStatus.interrupted)
await self.bridge.publish_terminal(
self.run_id,
StreamStatus.CANCELLED,
{"cancel_mode": cancel_mode},
)
self.result.status = RunStatus.interrupted
logger.info("Run %s cancelled (mode=%s)", self.run_id, cancel_mode)
await self._emit(LifecycleEventType.RUN_CANCELLED, payload)
async def _finish_failed(self, exc: Exception) -> None:
error_msg = str(exc)
logger.exception("Run %s failed: %s", self.run_id, error_msg)
await self.run_manager.set_status(self.run_id, RunStatus.error, error=error_msg)
await self.bridge.publish_terminal(
self.run_id,
StreamStatus.ERRORED,
{"message": error_msg, "name": type(exc).__name__},
)
self.result.status = RunStatus.error
self.result.error = error_msg
await self._emit(
LifecycleEventType.RUN_FAILED,
{
"error": error_msg,
"error_type": type(exc).__name__,
"completion_data": self._completion_data(),
},
)
def _completion_data(self) -> dict[str, object]:
if self._callback_artifacts is None:
return {}
return self._callback_artifacts.completion_data().to_dict()
def _callback_title(self) -> str | None:
if self._callback_artifacts is None:
return None
return self._callback_artifacts.title()
async def _extract_title_from_checkpoint(self) -> str | None:
if self.checkpointer is None:
return None
try:
ckpt_config = {"configurable": {"thread_id": self.thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await self.checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is not None:
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
return ckpt.get("channel_values", {}).get("title")
except Exception:
logger.debug("Failed to extract title from checkpoint for thread %s", self.thread_id)
return None
def _map_run_status_to_thread_status(self, status: RunStatus) -> str:
if status == RunStatus.success:
return "idle"
if status == RunStatus.interrupted:
return "interrupted"
if status in (RunStatus.error, RunStatus.timeout):
return "error"
return "running"
def _abort_requested(self) -> bool:
if self.handle is not None:
return self.handle.cancel_event.is_set()
return self.record.abort_event.is_set()
def _abort_action(self) -> str:
if self.handle is not None:
return self.handle.cancel_action
return self.record.abort_action
async def _stream(self) -> None:
runnable_config = RunnableConfig(**self._runnable_config)
if len(self._lg_modes) == 1 and not self.stream_subgraphs:
single_mode = self._lg_modes[0]
async for chunk in self._agent.astream(
self.graph_input,
config=runnable_config,
stream_mode=single_mode,
):
if self._abort_requested():
logger.info("Run %s abort requested - stopping", self.run_id)
break
if should_filter_event(single_mode, chunk):
continue
await self.bridge.publish(
self.run_id,
external_stream_event_name(single_mode),
serialize(chunk, mode=single_mode),
)
return
async for item in self._agent.astream(
self.graph_input,
config=runnable_config,
stream_mode=self._lg_modes,
subgraphs=self.stream_subgraphs,
):
if self._abort_requested():
logger.info("Run %s abort requested - stopping", self.run_id)
break
mode, chunk = unpack_stream_item(item, self._lg_modes, stream_subgraphs=self.stream_subgraphs)
if mode is None:
continue
if should_filter_event(mode, chunk):
continue
await self.bridge.publish(
self.run_id,
external_stream_event_name(mode),
serialize(chunk, mode=mode),
)
async def _finish_after_stream(self) -> None:
if self._abort_requested():
action = self._abort_action()
cancel_mode = "rollback" if action == "rollback" else "interrupt"
await self._finish_aborted(cancel_mode)
return
await self._finish_success()
async def _emit_final_thread_status(self) -> None:
final_thread_status = self._map_run_status_to_thread_status(self.result.status)
await self._emit(
LifecycleEventType.THREAD_STATUS_UPDATED,
{"status": final_thread_status},
)
async def run(self) -> RunResult:
try:
await self._start()
await self._prepare()
await self._stream()
await self._finish_after_stream()
except asyncio.CancelledError:
await self._finish_aborted("task_cancelled")
except Exception as exc:
await self._finish_failed(exc)
finally:
await self._emit_final_thread_status()
if self._callback_artifacts is not None:
await self._callback_artifacts.flush()
await self.bridge.cleanup(self.run_id)
return self.result
__all__ = ["_RunExecution"]
@@ -0,0 +1,93 @@
"""Execution-local stream processing helpers."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class StreamItem:
"""Normalized stream item from LangGraph."""
mode: str
chunk: Any
_FILTERED_NODES = frozenset({"__start__", "__end__"})
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
def normalize_stream_modes(requested_modes: list[str] | None) -> list[str]:
"""Normalize requested stream modes to valid LangGraph modes."""
input_modes: list[str] = list(requested_modes or ["values"])
lg_modes: list[str] = []
for mode in input_modes:
if mode == "messages-tuple":
lg_modes.append("messages")
elif mode == "events":
logger.info("'events' stream_mode not supported (requires astream_events). Skipping.")
continue
elif mode in _VALID_LG_MODES:
lg_modes.append(mode)
if not lg_modes:
lg_modes = ["values"]
seen: set[str] = set()
deduped: list[str] = []
for mode in lg_modes:
if mode not in seen:
seen.add(mode)
deduped.append(mode)
return deduped
def unpack_stream_item(
item: Any,
lg_modes: list[str],
*,
stream_subgraphs: bool,
) -> tuple[str | None, Any]:
"""Unpack a multi-mode or subgraph stream item into ``(mode, chunk)``."""
if stream_subgraphs:
if isinstance(item, tuple) and len(item) == 3:
_namespace, mode, chunk = item
return str(mode), chunk
if isinstance(item, tuple) and len(item) == 2:
mode, chunk = item
return str(mode), chunk
return None, None
if isinstance(item, tuple) and len(item) == 2:
mode, chunk = item
return str(mode), chunk
return lg_modes[0] if lg_modes else None, item
def should_filter_event(mode: str, chunk: Any) -> bool:
"""Determine whether a stream event should be filtered before publish."""
if mode == "updates" and isinstance(chunk, dict):
node_names = set(chunk.keys())
if node_names & _FILTERED_NODES:
return True
if mode == "messages" and isinstance(chunk, tuple) and len(chunk) == 2:
_, metadata = chunk
if isinstance(metadata, dict):
node = metadata.get("langgraph_node", "")
if node in _FILTERED_NODES:
return True
return False
def external_stream_event_name(mode: str) -> str:
"""Map LangGraph internal modes to the external SSE event contract."""
return mode
@@ -0,0 +1,78 @@
"""Active execution handle management for runs domain."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from ...types import CancelAction
@dataclass
class RunHandle:
"""In-process control handle for an active run."""
run_id: str
task: asyncio.Task[Any] | None = None
cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
cancel_action: CancelAction = "interrupt"
class RunSupervisor:
"""Own and control active run handles within the current process."""
def __init__(self) -> None:
self._handles: dict[str, RunHandle] = {}
self._lock = asyncio.Lock()
async def launch(
self,
run_id: str,
*,
runner: Callable[[RunHandle], Awaitable[Any]],
) -> RunHandle:
"""Create a handle and start a background task for it."""
handle = RunHandle(run_id=run_id)
async with self._lock:
if run_id in self._handles:
raise RuntimeError(f"Run {run_id} is already active")
self._handles[run_id] = handle
task = asyncio.create_task(runner(handle))
handle.task = task
task.add_done_callback(lambda _: asyncio.create_task(self.cleanup(run_id)))
return handle
async def cancel(
self,
run_id: str,
*,
action: CancelAction = "interrupt",
) -> bool:
"""Signal cancellation for an active handle."""
async with self._lock:
handle = self._handles.get(run_id)
if handle is None:
return False
handle.cancel_action = action
handle.cancel_event.set()
if handle.task is not None and not handle.task.done():
handle.task.cancel()
return True
def get_handle(self, run_id: str) -> RunHandle | None:
"""Return the active handle for a run, if any."""
return self._handles.get(run_id)
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
"""Remove a handle after optional delay."""
if delay > 0:
await asyncio.sleep(delay)
async with self._lock:
self._handles.pop(run_id, None)
@@ -0,0 +1,253 @@
"""In-memory run registry with optional persistent RunStore backing."""
from __future__ import annotations
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Literal
from ..types import RunStatus
logger = logging.getLogger(__name__)
def _now_iso() -> str:
return datetime.now(UTC).isoformat()
@dataclass
class RunRecord:
"""Mutable record for a single run."""
run_id: str
thread_id: str
assistant_id: str | None
status: RunStatus
on_disconnect: Literal["cancel", "continue"]
multitask_strategy: str = "reject"
metadata: dict = field(default_factory=dict)
kwargs: dict = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
task: asyncio.Task | None = field(default=None, repr=False)
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
abort_action: str = "interrupt"
error: str | None = None
class RunManager:
"""In-memory run registry with optional persistent RunStore backing.
All mutations are protected by an asyncio lock. When a ``store`` is
provided, serializable metadata is also persisted to the store so
that run history survives process restarts.
"""
def __init__(self, store: Any | None = None) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
self._store = store
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
"""Best-effort persist run record to backing store."""
if self._store is None:
return
try:
await self._store.put(
record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
multitask_strategy=record.multitask_strategy,
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
follow_up_to_run_id=follow_up_to_run_id,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def create(
self,
thread_id: str,
assistant_id: str | None = None,
*,
on_disconnect: Literal["cancel", "continue"] = "cancel",
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Create a new pending run and register it."""
run_id = str(uuid.uuid4())
now = _now_iso()
record = RunRecord(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
on_disconnect=on_disconnect,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
updated_at=now,
)
async with self._lock:
self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
def get(self, run_id: str) -> RunRecord | None:
"""Return a run record by ID, or ``None``."""
return self._runs.get(run_id)
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
"""Return all runs for a given thread, newest first."""
async with self._lock:
# Dict insertion order matches creation order, so reversing it gives
# us deterministic newest-first results even when timestamps tie.
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
"""Transition a run to a new status."""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
logger.warning("set_status called for unknown run %s", run_id)
return
record.status = status
record.updated_at = _now_iso()
if error is not None:
record.error = error
if self._store is not None:
try:
await self._store.update_status(run_id, status.value, error=error)
except Exception:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
"""Request cancellation of a run.
Args:
run_id: The run ID to cancel.
action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state.
Sets the abort event with the action reason and cancels the asyncio task.
Returns ``True`` if the run was in-flight and cancellation was initiated.
"""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
return False
if record.status not in (RunStatus.pending, RunStatus.running):
return False
record.abort_action = action
record.abort_event.set()
if record.task is not None and not record.task.done():
record.task.cancel()
record.status = RunStatus.interrupted
record.updated_at = _now_iso()
logger.info("Run %s cancelled (action=%s)", run_id, action)
return True
async def create_or_reject(
self,
thread_id: str,
assistant_id: str | None = None,
*,
on_disconnect: Literal["cancel", "continue"] = "cancel",
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
follow_up_to_run_id: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
For ``reject`` strategy, raises ``ConflictError`` if thread
already has a pending/running run. For ``interrupt``/``rollback``,
cancels inflight runs before creating.
This method holds the lock across both the check and the insert,
eliminating the TOCTOU race in separate ``has_inflight`` + ``create``.
"""
run_id = str(uuid.uuid4())
now = _now_iso()
_supported_strategies = ("reject", "interrupt", "rollback")
async with self._lock:
if multitask_strategy not in _supported_strategies:
raise UnsupportedStrategyError(f"Multitask strategy '{multitask_strategy}' is not yet supported. Supported strategies: {', '.join(_supported_strategies)}")
inflight = [r for r in self._runs.values() if r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running)]
if multitask_strategy == "reject" and inflight:
raise ConflictError(f"Thread {thread_id} already has an active run")
if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
len(inflight),
thread_id,
multitask_strategy,
)
record = RunRecord(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
on_disconnect=on_disconnect,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
updated_at=now,
)
self._runs[run_id] = record
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
async def has_inflight(self, thread_id: str) -> bool:
"""Return ``True`` if *thread_id* has a pending or running run."""
async with self._lock:
return any(r.thread_id == thread_id and r.status in (RunStatus.pending, RunStatus.running) for r in self._runs.values())
async def cleanup(self, run_id: str, *, delay: float = 300) -> None:
"""Remove a run record after an optional delay."""
if delay > 0:
await asyncio.sleep(delay)
async with self._lock:
self._runs.pop(run_id, None)
logger.debug("Run record %s cleaned up", run_id)
class ConflictError(Exception):
"""Raised when multitask_strategy=reject and thread has inflight runs."""
class UnsupportedStrategyError(Exception):
"""Raised when a multitask_strategy value is not yet implemented."""
@@ -0,0 +1,42 @@
"""Execution plan builder for runs domain."""
from __future__ import annotations
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal
from ..types import RunRecord, RunSpec
@dataclass(frozen=True)
class ExecutionPlan:
"""Normalized execution inputs derived from a run record and spec."""
record: RunRecord
graph_input: dict[str, Any]
runnable_config: dict[str, Any]
stream_modes: list[str]
stream_subgraphs: bool
interrupt_before: list[str] | Literal["*"] | None
interrupt_after: list[str] | Literal["*"] | None
class ExecutionPlanner:
"""Build executor-ready plans from public run specs."""
def build(self, record: RunRecord, spec: RunSpec) -> ExecutionPlan:
return ExecutionPlan(
record=record,
graph_input=self._normalize_graph_input(spec.input),
runnable_config=deepcopy(spec.runnable_config),
stream_modes=list(spec.stream_modes),
stream_subgraphs=spec.stream_subgraphs,
interrupt_before=spec.interrupt_before,
interrupt_after=spec.interrupt_after,
)
def _normalize_graph_input(self, raw_input: dict[str, Any] | None) -> dict[str, Any]:
if raw_input is None:
return {}
return deepcopy(raw_input)
@@ -0,0 +1,146 @@
"""In-memory run registry for runs domain state."""
from __future__ import annotations
import asyncio
import uuid
from datetime import datetime, timezone
from typing import Any
from ..types import INFLIGHT_STATUSES, RunRecord, RunSpec, RunStatus
class RunRegistry:
"""In-memory source of truth for run records and their status."""
def __init__(self) -> None:
self._records: dict[str, RunRecord] = {}
self._thread_index: dict[str, set[str]] = {} # thread_id -> set[run_id]
self._lock = asyncio.Lock()
async def create(self, spec: RunSpec) -> RunRecord:
"""Create a new RunRecord from RunSpec."""
run_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
record = RunRecord(
run_id=run_id,
thread_id=spec.scope.thread_id,
assistant_id=spec.assistant_id,
status="pending",
temporary=spec.scope.temporary,
multitask_strategy=spec.multitask_strategy,
metadata=dict(spec.metadata),
follow_up_to_run_id=spec.follow_up_to_run_id,
created_at=now,
updated_at=now,
)
async with self._lock:
self._records[run_id] = record
# Update thread index
if spec.scope.thread_id not in self._thread_index:
self._thread_index[spec.scope.thread_id] = set()
self._thread_index[spec.scope.thread_id].add(run_id)
return record
def get(self, run_id: str) -> RunRecord | None:
"""Get RunRecord by run_id."""
return self._records.get(run_id)
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
"""List all RunRecords for a thread."""
async with self._lock:
run_ids = self._thread_index.get(thread_id, set())
return [self._records[rid] for rid in run_ids if rid in self._records]
async def set_status(
self,
run_id: str,
status: RunStatus,
*,
error: str | None = None,
started_at: str | None = None,
ended_at: str | None = None,
) -> None:
"""Update run status and optional fields."""
async with self._lock:
record = self._records.get(run_id)
if record is None:
return
record.status = status
record.updated_at = datetime.now(timezone.utc).isoformat()
if error is not None:
record.error = error
if started_at is not None:
record.started_at = started_at
if ended_at is not None:
record.ended_at = ended_at
async def has_inflight(self, thread_id: str) -> bool:
"""Check if thread has any inflight runs."""
async with self._lock:
run_ids = self._thread_index.get(thread_id, set())
for rid in run_ids:
record = self._records.get(rid)
if record and record.status in INFLIGHT_STATUSES:
return True
return False
async def interrupt_inflight(self, thread_id: str) -> list[str]:
"""
Mark all inflight runs for a thread as interrupted.
Returns list of interrupted run_ids.
"""
interrupted: list[str] = []
now = datetime.now(timezone.utc).isoformat()
async with self._lock:
run_ids = self._thread_index.get(thread_id, set())
for rid in run_ids:
record = self._records.get(rid)
if record and record.status in INFLIGHT_STATUSES:
record.status = "interrupted"
record.updated_at = now
record.ended_at = now
interrupted.append(rid)
return interrupted
async def update_metadata(self, run_id: str, metadata: dict[str, Any]) -> None:
"""Update run metadata."""
async with self._lock:
record = self._records.get(run_id)
if record is not None:
record.metadata.update(metadata)
record.updated_at = datetime.now(timezone.utc).isoformat()
async def delete(self, run_id: str) -> bool:
"""Delete a run record. Returns True if deleted."""
async with self._lock:
record = self._records.pop(run_id, None)
if record is None:
return False
# Update thread index
thread_runs = self._thread_index.get(record.thread_id)
if thread_runs:
thread_runs.discard(run_id)
return True
def count(self) -> int:
"""Return total number of records."""
return len(self._records)
def count_by_status(self, status: RunStatus) -> int:
"""Return count of records with given status."""
return sum(1 for r in self._records.values() if r.status == status)
# Compatibility alias during the refactor.
RuntimeRunRegistry = RunRegistry
@@ -0,0 +1,76 @@
"""Internal run stream adapter over StreamBridge."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from deerflow.runtime.stream_bridge import JSONValue, StreamBridge, StreamEvent
from deerflow.runtime.stream_bridge import StreamStatus
class RunStreamService:
"""Thin runs-domain adapter over the harness stream bridge contract."""
def __init__(self, bridge: "StreamBridge") -> None:
self._bridge = bridge
async def publish_event(
self,
run_id: str,
*,
event: str,
data: "JSONValue",
) -> str:
"""Publish a replayable run event."""
return await self._bridge.publish(run_id, event, data)
async def publish_end(self, run_id: str) -> str:
"""Publish a successful terminal signal."""
return await self._bridge.publish_terminal(run_id, StreamStatus.ENDED)
async def publish_cancelled(
self,
run_id: str,
*,
data: "JSONValue" = None,
) -> str:
"""Publish a cancelled terminal signal."""
return await self._bridge.publish_terminal(
run_id,
StreamStatus.CANCELLED,
data,
)
async def publish_error(
self,
run_id: str,
*,
data: "JSONValue",
) -> str:
"""Publish a failed terminal signal."""
return await self._bridge.publish_terminal(
run_id,
StreamStatus.ERRORED,
data,
)
def subscribe(
self,
run_id: str,
*,
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[StreamEvent]:
"""Subscribe to a run stream with resume support."""
return self._bridge.subscribe(
run_id,
last_event_id=last_event_id,
heartbeat_interval=heartbeat_interval,
)
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
"""Release per-run bridge resources after completion."""
await self._bridge.cleanup(run_id, delay=delay)
@@ -0,0 +1,95 @@
"""Internal run wait helpers based on stream events."""
from __future__ import annotations
from typing import Any
from deerflow.runtime.stream_bridge import StreamEvent
from .streams import RunStreamService
class WaitTimeoutError(TimeoutError):
"""Raised when wait times out."""
pass
class WaitErrorResult:
"""Represents an error result from wait."""
def __init__(self, error: str, details: dict[str, Any] | None = None) -> None:
self.error = error
self.details = details or {}
def to_dict(self) -> dict[str, Any]:
return {"error": self.error, **self.details}
class RunWaitService:
"""
Wait service for runs domain.
Based on RunStreamService.subscribe(), implements wait semantics.
Phase 1 behavior:
- Records last 'values' event
- On 'error', returns unified error structure
- On 'end' only, returns last values
"""
TERMINAL_EVENTS = frozenset({"end", "error", "cancel"})
def __init__(self, stream_service: RunStreamService) -> None:
self._stream_service = stream_service
async def wait_for_terminal(
self,
run_id: str,
*,
last_event_id: str | None = None,
) -> StreamEvent | None:
"""Block until the next terminal event for a run is observed."""
async for event in self._stream_service.subscribe(
run_id,
last_event_id=last_event_id,
):
if event.event in self.TERMINAL_EVENTS:
return event
return None
async def wait_for_values_or_error(
self,
run_id: str,
*,
last_event_id: str | None = None,
) -> dict[str, Any] | WaitErrorResult | None:
"""
Wait for run to complete and return final values or error.
Returns:
- dict: Final values if successful
- WaitErrorResult: If run failed
- None: If no values were produced
"""
last_values: dict[str, Any] | None = None
async for event in self._stream_service.subscribe(
run_id,
last_event_id=last_event_id,
):
if event.event == "values":
last_values = event.data
elif event.event == "error":
return WaitErrorResult(
error=str(event.data) if event.data else "Unknown error",
details={"run_id": run_id},
)
elif event.event in self.TERMINAL_EVENTS:
# Stream ended, return last values
break
return last_values