refactor(runtime): add run DDD boundary skeleton

This commit is contained in:
rayhpeng
2026-06-01 09:22:32 +08:00
parent 9f3be2a9fa
commit 30bb2d5149
24 changed files with 1075 additions and 20 deletions
@@ -0,0 +1,33 @@
"""Run runtime domain model."""
from .errors import InvalidRunTransition, RunDomainError
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, UserId
from .model import Run
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
# Public API of the package: every name listed here is imported from a sibling
# module above, and this list is what ``from <package> import *`` binds.
__all__ = [
    "AssistantId",
    "CancelAction",
    "CancelPolicy",
    "DisconnectMode",
    "EventSeq",
    "InvalidRunTransition",
    "MultitaskDecision",
    "MultitaskPolicy",
    "MultitaskStrategy",
    "Run",
    "RunCancelled",
    "RunCompleted",
    "RunCreated",
    "RunDomainError",
    "RunEvent",
    "RunFailed",
    "RunId",
    "RunScope",
    "RunStarted",
    "RunStatus",
    "ThreadId",
    "UserId",
]
@@ -0,0 +1,24 @@
"""Domain-level errors for run lifecycle operations."""
from __future__ import annotations
from .value_objects import RunStatus
class RunDomainError(Exception):
    """Base class for run runtime domain errors."""


class InvalidRunTransition(RunDomainError):
    """Raised when a run status transition violates lifecycle rules.

    Attributes:
        current: Status the run was in when the transition was attempted.
        target: Status that was requested.
    """

    def __init__(self, current: RunStatus, target: RunStatus) -> None:
        """Build the error message from the two statuses and keep both on the instance.

        Args:
            current: Status the run currently has.
            target: Status the caller tried to move to.
        """
        super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
        self.current = current
        self.target = target

    def __reduce__(self) -> tuple[object, ...]:
        """Rebuild from ``(current, target)`` when the error is copied or pickled.

        ``BaseException`` reconstructs instances by calling the class with
        ``self.args``. Here ``args`` holds only the formatted message, so that
        call would fail against the two-argument ``__init__``. Returning the
        real constructor arguments keeps ``copy`` and ``pickle`` working while
        leaving ``args`` and ``str()`` unchanged.
        """
        # The third item mirrors BaseException's default: instance attributes
        # are restored through __setstate__ after construction.
        return (type(self), (self.current, self.target), dict(vars(self)))
# Names bound by ``from .errors import *``; both are defined in this module.
__all__ = [
    "InvalidRunTransition",
    "RunDomainError",
]
@@ -0,0 +1,64 @@
"""Domain events emitted by the run aggregate."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .identifiers import AssistantId, RunId, ThreadId
from .value_objects import CancelAction, RunStatus
@dataclass(frozen=True)
class RunCreated:
    """Domain event: a new run was created (recorded by ``Run.create``)."""

    run_id: RunId
    thread_id: ThreadId
    # Falls back to now_iso() when the emitter supplies no timestamp.
    occurred_at: str = field(default_factory=now_iso)
    assistant_id: AssistantId | None = None
    # frozen=True only blocks attribute rebinding; the dict itself stays mutable.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class RunStarted:
    """Domain event: a run moved to ``RunStatus.running``."""

    run_id: RunId
    thread_id: ThreadId
    # Falls back to now_iso() when the emitter supplies no timestamp.
    occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunCompleted:
    """Domain event: a run moved to ``RunStatus.success``."""

    run_id: RunId
    thread_id: ThreadId
    # Falls back to now_iso() when the emitter supplies no timestamp.
    occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunFailed:
    """Domain event: a run ended unsuccessfully."""

    run_id: RunId
    thread_id: ThreadId
    # The Run aggregate emits this event for both RunStatus.error and
    # RunStatus.timeout; this field records which one applied.
    status: RunStatus
    # Falls back to now_iso() when the emitter supplies no timestamp.
    occurred_at: str = field(default_factory=now_iso)
    # Optional failure description; None when the emitter gave none.
    error: str | None = None
@dataclass(frozen=True)
class RunCancelled:
    """Domain event: a run moved to ``RunStatus.interrupted``."""

    run_id: RunId
    thread_id: ThreadId
    # Falls back to now_iso() when the emitter supplies no timestamp.
    occurred_at: str = field(default_factory=now_iso)
    # Cancellation action that was requested; defaults to a plain interrupt.
    action: CancelAction = CancelAction.interrupt
# Union of every event class defined in this module; used as the annotation
# for "any run domain event".
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
# Names bound by ``from .events import *``: the five event classes plus the
# RunEvent union alias.
__all__ = [
    "RunCancelled",
    "RunCompleted",
    "RunCreated",
    "RunEvent",
    "RunFailed",
    "RunStarted",
]
@@ -0,0 +1,27 @@
"""Lightweight identifiers for the run runtime domain."""
from __future__ import annotations
from typing import NewType
# Distinct static types so a type checker can tell the different kinds of
# identifier apart. NewType adds no runtime wrapper: each value is still a
# plain str when the program runs.
RunId = NewType("RunId", str)
ThreadId = NewType("ThreadId", str)
AssistantId = NewType("AssistantId", str)
UserId = NewType("UserId", str)
def require_non_empty(value: str, *, field_name: str) -> str:
    """Strip surrounding whitespace from an identifier and insist something remains.

    Args:
        value: Raw identifier text.
        field_name: Name used in the error message when the value is blank.

    Returns:
        ``value`` with leading and trailing whitespace removed.

    Raises:
        ValueError: If nothing is left after stripping.
    """
    if stripped := value.strip():
        return stripped
    raise ValueError(f"{field_name} must not be empty")
# Names bound by ``from .identifiers import *``: the four identifier types and
# the validation helper.
__all__ = [
    "AssistantId",
    "RunId",
    "ThreadId",
    "UserId",
    "require_non_empty",
]
@@ -0,0 +1,193 @@
"""Run aggregate root and lifecycle invariants."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .errors import InvalidRunTransition
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
# Keep lifecycle transitions explicit so later application code cannot invent
# ad hoc status moves outside the aggregate.
#
# Maps each current status to the set of statuses it may move to. An empty
# frozenset marks a terminal status: Run.is_terminal is defined as "no allowed
# targets". Note that nothing may move back to pending, and pending may not
# jump straight to success.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
    RunStatus.pending: frozenset(
        {
            RunStatus.running,
            RunStatus.error,
            RunStatus.timeout,
            RunStatus.interrupted,
        }
    ),
    RunStatus.running: frozenset(
        {
            RunStatus.success,
            RunStatus.error,
            RunStatus.timeout,
            RunStatus.interrupted,
        }
    ),
    RunStatus.success: frozenset(),
    RunStatus.error: frozenset(),
    RunStatus.timeout: frozenset(),
    RunStatus.interrupted: frozenset(),
}
@dataclass
class Run:
    """Run aggregate root.

    The aggregate owns lifecycle invariants only. Infrastructure concerns such
    as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
    this model.

    Status changes go through the ``mark_*`` methods. Each accepted change is
    validated against ``_ALLOWED_TRANSITIONS``, stamps ``updated_at`` and queues
    one domain event, which the caller drains with :meth:`pull_events`.
    """

    run_id: RunId
    thread_id: ThreadId
    status: RunStatus
    assistant_id: AssistantId | None = None
    scope: RunScope = RunScope.stateful
    multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
    metadata: dict[str, Any] = field(default_factory=dict)
    kwargs: dict[str, Any] = field(default_factory=dict)
    created_at: str = field(default_factory=now_iso)
    updated_at: str = field(default_factory=now_iso)
    error: str | None = None
    model_name: str | None = None
    # Events recorded but not yet drained; not an __init__ argument and hidden
    # from repr().
    _pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self) -> None:
        """Normalize the identifiers and reject missing or blank ones.

        Raises:
            ValueError: If ``run_id`` or ``thread_id`` is ``None`` or blank, or
                if ``assistant_id`` is given but blank.
        """
        # str(None) is the non-empty text "None", so a missing required id has
        # to be caught before the str() coercion below can disguise it.
        for field_name in ("run_id", "thread_id"):
            if getattr(self, field_name) is None:
                raise ValueError(f"{field_name} must not be empty")
        self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
        self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
        if self.assistant_id is not None:
            self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))

    @classmethod
    def create(
        cls,
        *,
        run_id: RunId,
        thread_id: ThreadId,
        assistant_id: AssistantId | None = None,
        scope: RunScope = RunScope.stateful,
        multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
        metadata: dict[str, Any] | None = None,
        kwargs: dict[str, Any] | None = None,
        model_name: str | None = None,
        created_at: str | None = None,
    ) -> Run:
        """Build a new run in ``RunStatus.pending`` and queue a ``RunCreated`` event.

        Args:
            run_id: Identifier of the new run.
            thread_id: Identifier of the thread the run belongs to.
            assistant_id: Optional assistant identifier.
            scope: Conversation scope for the run.
            multitask_strategy: Concurrency strategy stored on the run.
            metadata: Optional metadata; an empty dict is used when omitted.
            kwargs: Optional extra arguments; an empty dict is used when omitted.
            model_name: Optional model name stored on the run.
            created_at: Timestamp to use; ``now_iso()`` is called when omitted
                or empty.

        Returns:
            The new run, with one ``RunCreated`` event pending.

        Raises:
            ValueError: If an identifier fails validation in ``__post_init__``.
        """
        # One timestamp feeds created_at, updated_at and the event so all three
        # agree exactly.
        timestamp = created_at or now_iso()
        run = cls(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            status=RunStatus.pending,
            scope=scope,
            multitask_strategy=multitask_strategy,
            metadata=metadata or {},
            kwargs=kwargs or {},
            created_at=timestamp,
            updated_at=timestamp,
            model_name=model_name,
        )
        run._record_event(
            RunCreated(
                run_id=run.run_id,
                thread_id=run.thread_id,
                occurred_at=timestamp,
                assistant_id=run.assistant_id,
                # Shallow copy: later top-level changes to run.metadata do not
                # alter the already-recorded event.
                metadata=dict(run.metadata),
            )
        )
        return run

    @property
    def is_terminal(self) -> bool:
        """Whether the current status has no allowed outgoing transitions."""
        return not _ALLOWED_TRANSITIONS[self.status]

    def pull_events(self) -> tuple[RunEvent, ...]:
        """Return the pending events in recording order and empty the queue."""
        # Domain events are drained by the application layer after the aggregate
        # has accepted a state change.
        events = tuple(self._pending_events)
        self._pending_events.clear()
        return events

    def mark_started(self, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.running``."""
        self._transition_to(RunStatus.running, at=at)

    def mark_completed(self, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.success``."""
        self._transition_to(RunStatus.success, at=at)

    def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.error``, optionally recording ``error``."""
        self._transition_to(RunStatus.error, error=error, at=at)

    def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.timeout``, optionally recording ``error``."""
        self._transition_to(RunStatus.timeout, error=error, at=at)

    def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
        """Move the run to ``RunStatus.interrupted``, tagging the event with ``action``."""
        self._transition_to(RunStatus.interrupted, action=action, at=at)

    def _transition_to(
        self,
        target: RunStatus,
        *,
        error: str | None = None,
        action: CancelAction = CancelAction.interrupt,
        at: str | None = None,
    ) -> None:
        """Validate and apply a status change, then queue the matching event.

        Asking for the status the run already has is a silent no-op: nothing is
        updated and no event is queued.

        Args:
            target: Status to move to.
            error: Error text to store; the stored value is left alone when
                this is ``None``.
            action: Cancel action forwarded to the ``RunCancelled`` event.
            at: Timestamp to use; ``now_iso()`` is called when omitted or empty.

        Raises:
            InvalidRunTransition: If ``target`` is not allowed from the current
                status, or no event shape exists for it.
        """
        if target == self.status:
            return
        if target not in _ALLOWED_TRANSITIONS[self.status]:
            raise InvalidRunTransition(self.status, target)
        timestamp = at or now_iso()
        # Build the event before touching any state: if no event shape exists
        # for the target, the aggregate stays unchanged and the raised error
        # still reports the status the run really had.
        event = self._event_for_transition(target, timestamp, error=error, action=action)
        self.status = target
        self.updated_at = timestamp
        if error is not None:
            self.error = error
        self._record_event(event)

    def _event_for_transition(
        self,
        target: RunStatus,
        occurred_at: str,
        *,
        error: str | None,
        action: CancelAction,
    ) -> RunEvent:
        """Build the domain event that corresponds to entering ``target``.

        Raises:
            InvalidRunTransition: If ``target`` has no event defined here.
        """
        # Keep event construction next to the transition rules so a new status
        # cannot be added without an explicit durable event shape.
        if target == RunStatus.running:
            return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
        if target == RunStatus.success:
            return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
        if target in (RunStatus.error, RunStatus.timeout):
            return RunFailed(
                run_id=self.run_id,
                thread_id=self.thread_id,
                status=target,
                occurred_at=occurred_at,
                error=error,
            )
        if target == RunStatus.interrupted:
            return RunCancelled(
                run_id=self.run_id,
                thread_id=self.thread_id,
                occurred_at=occurred_at,
                action=action,
            )
        raise InvalidRunTransition(self.status, target)

    def _record_event(self, event: RunEvent) -> None:
        """Append ``event`` to the pending queue."""
        self._pending_events.append(event)
# Names bound by ``from .model import *``. "RunStatus" is imported from
# .value_objects above and re-exported here alongside the aggregate.
__all__ = [
    "Run",
    "RunStatus",
]
@@ -0,0 +1,50 @@
"""Domain policies for run concurrency and cancellation."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from enum import StrEnum
from .model import Run
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
class MultitaskDecision(StrEnum):
    """Application-level decision produced by a multitask policy."""

    # Returned by MultitaskPolicy.decide when no pending or running run exists.
    allow = "allow"
    # Returned for MultitaskStrategy.reject.
    reject = "reject"
    # Returned for MultitaskStrategy.interrupt and MultitaskStrategy.rollback.
    cancel_existing = "cancel_existing"
    # Returned for every remaining strategy (the fall-through case).
    enqueue = "enqueue"
@dataclass(frozen=True)
class MultitaskPolicy:
strategy: MultitaskStrategy = MultitaskStrategy.reject
def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
inflight = [run for run in active_runs if run.status in (RunStatus.pending, RunStatus.running)]
if not inflight:
return MultitaskDecision.allow
if self.strategy == MultitaskStrategy.reject:
return MultitaskDecision.reject
if self.strategy in (MultitaskStrategy.interrupt, MultitaskStrategy.rollback):
return MultitaskDecision.cancel_existing
return MultitaskDecision.enqueue
@dataclass(frozen=True)
class CancelPolicy:
    """Immutable wrapper around the cancel action requested for a run."""

    # Defaults to a plain interrupt.
    action: CancelAction = CancelAction.interrupt

    @property
    def rolls_back_checkpoint(self) -> bool:
        """Whether the configured action is ``CancelAction.rollback``."""
        return self.action == CancelAction.rollback
# Names bound by ``from .policies import *``; all three are defined in this module.
__all__ = [
    "CancelPolicy",
    "MultitaskDecision",
    "MultitaskPolicy",
]
@@ -0,0 +1,88 @@
"""Domain value objects for run lifecycle semantics."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class RunStatus(StrEnum):
    """Lifecycle status of a single run."""

    # Non-terminal statuses (absent from TERMINAL_RUN_STATUSES).
    pending = "pending"
    running = "running"
    # Terminal statuses (members of TERMINAL_RUN_STATUSES).
    success = "success"
    error = "error"
    timeout = "timeout"
    interrupted = "interrupted"
class DisconnectMode(StrEnum):
    """Behaviour when the SSE consumer disconnects."""

    cancel = "cancel"
    # Trailing underscore because ``continue`` is a Python keyword; the value
    # itself is the plain string "continue".
    continue_ = "continue"
class RunScope(StrEnum):
    """Conversation scope for a run."""

    # NOTE(review): this module only names the scopes; what each one implies
    # for thread or state handling is not defined here - confirm with callers.
    stateful = "stateful"
    stateless = "stateless"
    temporary_thread = "temporary_thread"
class MultitaskStrategy(StrEnum):
    """Concurrency strategy for a new run on a thread."""

    # MultitaskPolicy.decide maps these, when a run is already in flight, to:
    # reject -> reject, interrupt/rollback -> cancel_existing, enqueue -> enqueue.
    reject = "reject"
    interrupt = "interrupt"
    rollback = "rollback"
    enqueue = "enqueue"
class CancelAction(StrEnum):
    """Cancellation action requested by an API or supervisor."""

    interrupt = "interrupt"
    # The only action for which CancelPolicy.rolls_back_checkpoint is True.
    rollback = "rollback"
# Statuses treated as final by is_terminal_status: every RunStatus member
# except pending and running.
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
    {
        RunStatus.success,
        RunStatus.error,
        RunStatus.timeout,
        RunStatus.interrupted,
    }
)
def is_terminal_status(status: RunStatus) -> bool:
    """Return True when ``status`` is a member of ``TERMINAL_RUN_STATUSES``."""
    return status in TERMINAL_RUN_STATUSES
@dataclass(frozen=True, order=True)
class EventSeq:
    """Thread-local event sequence number.

    Instances are immutable and compare by ``value``; a negative value is
    rejected at construction time.
    """

    value: int

    def __post_init__(self) -> None:
        """Reject negative sequence numbers.

        Raises:
            ValueError: If ``value`` is below zero.
        """
        is_negative = self.value < 0
        if is_negative:
            raise ValueError("EventSeq must be non-negative")

    def next(self) -> EventSeq:
        """Return a new ``EventSeq`` whose value is one higher than this one."""
        successor = self.value + 1
        return EventSeq(successor)
# Names bound by ``from .value_objects import *``: the enums, EventSeq, and the
# terminal-status constant with its helper.
__all__ = [
    "CancelAction",
    "DisconnectMode",
    "EventSeq",
    "MultitaskStrategy",
    "RunScope",
    "RunStatus",
    "TERMINAL_RUN_STATUSES",
    "is_terminal_status",
]