mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 01:15:58 +00:00
194 lines
6.6 KiB
Python
194 lines
6.6 KiB
Python
"""Run aggregate root and lifecycle invariants."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from deerflow.utils.time import now_iso
|
|
|
|
from .errors import InvalidRunTransition
|
|
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
|
|
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
|
|
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
|
|
|
|
# The full lifecycle graph lives in this one table so that later application
# code cannot invent ad hoc status moves outside the aggregate. A status that
# maps to an empty set is terminal.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
    RunStatus.pending: frozenset(
        (RunStatus.running, RunStatus.error, RunStatus.timeout, RunStatus.interrupted)
    ),
    RunStatus.running: frozenset(
        (RunStatus.success, RunStatus.error, RunStatus.timeout, RunStatus.interrupted)
    ),
    # Terminal statuses: no outgoing transitions.
    **dict.fromkeys(
        (RunStatus.success, RunStatus.error, RunStatus.timeout, RunStatus.interrupted),
        frozenset(),
    ),
}
|
|
|
|
|
|
@dataclass
class Run:
    """Run aggregate root.

    The aggregate owns lifecycle invariants only. Infrastructure concerns such
    as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
    this model.

    Status changes go through the ``mark_*`` methods. Each one validates the
    move against ``_ALLOWED_TRANSITIONS`` and queues a domain event, which the
    caller later drains with :meth:`pull_events`.
    """

    run_id: RunId
    thread_id: ThreadId
    status: RunStatus
    assistant_id: AssistantId | None = None
    scope: RunScope = RunScope.stateful
    multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
    metadata: dict[str, Any] = field(default_factory=dict)
    kwargs: dict[str, Any] = field(default_factory=dict)
    created_at: str = field(default_factory=now_iso)
    updated_at: str = field(default_factory=now_iso)
    error: str | None = None
    model_name: str | None = None
    # Events recorded since the last pull_events() call; not an __init__ argument.
    _pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self) -> None:
        """Normalise identifiers through ``require_non_empty`` and their id types."""
        self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
        self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
        if self.assistant_id is not None:
            self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))

    @classmethod
    def create(
        cls,
        *,
        run_id: RunId,
        thread_id: ThreadId,
        assistant_id: AssistantId | None = None,
        scope: RunScope = RunScope.stateful,
        multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
        metadata: dict[str, Any] | None = None,
        kwargs: dict[str, Any] | None = None,
        model_name: str | None = None,
        created_at: str | None = None,
    ) -> Run:
        """Build a new run in ``RunStatus.pending`` and record a ``RunCreated`` event.

        Args:
            run_id: Identifier of the new run.
            thread_id: Identifier of the thread the run belongs to.
            assistant_id: Optional assistant identifier.
            scope: Run scope stored on the aggregate.
            multitask_strategy: Multitask strategy stored on the aggregate.
            metadata: Optional metadata; copied so the aggregate never shares
                the caller's dict.
            kwargs: Optional run keyword arguments; copied for the same reason.
            model_name: Optional model name stored on the aggregate.
            created_at: Timestamp used for ``created_at``, ``updated_at`` and
                the event; ``now_iso()`` is used when omitted or empty.

        Returns:
            The new aggregate with one pending ``RunCreated`` event.
        """
        timestamp = created_at or now_iso()
        run = cls(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            status=RunStatus.pending,
            scope=scope,
            multitask_strategy=multitask_strategy,
            # Shallow-copy caller input: `metadata or {}` aliased the caller's
            # dict whenever it was non-empty, so later caller-side mutation
            # leaked into the aggregate.
            metadata=dict(metadata) if metadata else {},
            kwargs=dict(kwargs) if kwargs else {},
            created_at=timestamp,
            updated_at=timestamp,
            model_name=model_name,
        )
        run._record_event(
            RunCreated(
                run_id=run.run_id,
                thread_id=run.thread_id,
                occurred_at=timestamp,
                assistant_id=run.assistant_id,
                # The event gets its own snapshot of the metadata.
                metadata=dict(run.metadata),
            )
        )
        return run

    @property
    def is_terminal(self) -> bool:
        """Whether the current status has no allowed outgoing transitions."""
        return not _ALLOWED_TRANSITIONS[self.status]

    def pull_events(self) -> tuple[RunEvent, ...]:
        """Return the pending events as a tuple and clear the internal queue."""
        # Domain events are drained by the application layer after the aggregate
        # has accepted a state change.
        events = tuple(self._pending_events)
        self._pending_events.clear()
        return events

    def mark_started(self, *, at: str | None = None) -> None:
        """Transition to ``RunStatus.running``."""
        self._transition_to(RunStatus.running, at=at)

    def mark_completed(self, *, at: str | None = None) -> None:
        """Transition to ``RunStatus.success``."""
        self._transition_to(RunStatus.success, at=at)

    def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
        """Transition to ``RunStatus.error``, optionally storing ``error``."""
        self._transition_to(RunStatus.error, error=error, at=at)

    def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
        """Transition to ``RunStatus.timeout``, optionally storing ``error``."""
        self._transition_to(RunStatus.timeout, error=error, at=at)

    def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
        """Transition to ``RunStatus.interrupted``, carrying ``action`` on the event."""
        self._transition_to(RunStatus.interrupted, action=action, at=at)

    def _transition_to(
        self,
        target: RunStatus,
        *,
        error: str | None = None,
        action: CancelAction = CancelAction.interrupt,
        at: str | None = None,
    ) -> None:
        """Move the aggregate to ``target`` and record the matching event.

        Requesting the current status is a silent no-op: nothing is mutated and
        no event is recorded.

        Args:
            target: Status to move to.
            error: Error text; stored on the aggregate only when not ``None``.
            action: Cancel action forwarded to the event builder.
            at: Timestamp for ``updated_at`` and the event; ``now_iso()`` is
                used when omitted or empty.

        Raises:
            InvalidRunTransition: If ``target`` is not reachable from the
                current status, or no event shape exists for ``target``.
        """
        if target == self.status:
            return
        if target not in _ALLOWED_TRANSITIONS[self.status]:
            raise InvalidRunTransition(self.status, target)

        timestamp = at or now_iso()
        # Build the event before touching any field. If no event shape exists
        # for `target`, the builder raises while the aggregate is still intact,
        # and its error reports the real source status.
        event = self._event_for_transition(target, timestamp, error=error, action=action)

        self.status = target
        self.updated_at = timestamp
        if error is not None:
            self.error = error
        self._record_event(event)

    def _event_for_transition(
        self,
        target: RunStatus,
        occurred_at: str,
        *,
        error: str | None,
        action: CancelAction,
    ) -> RunEvent:
        """Return the domain event that describes a move to ``target``.

        Raises:
            InvalidRunTransition: If ``target`` has no event shape here.
        """
        # Keep event construction next to the transition rules so a new status
        # cannot be added without an explicit durable event shape.
        if target == RunStatus.running:
            return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
        if target == RunStatus.success:
            return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
        if target in (RunStatus.error, RunStatus.timeout):
            # Both failure flavours share RunFailed; `status` tells them apart.
            return RunFailed(
                run_id=self.run_id,
                thread_id=self.thread_id,
                status=target,
                occurred_at=occurred_at,
                error=error,
            )
        if target == RunStatus.interrupted:
            return RunCancelled(
                run_id=self.run_id,
                thread_id=self.thread_id,
                occurred_at=occurred_at,
                action=action,
            )
        raise InvalidRunTransition(self.status, target)

    def _record_event(self, event: RunEvent) -> None:
        """Append ``event`` to the pending-event queue."""
        self._pending_events.append(event)
|
|
|
|
|
|
# Public API of this module.
__all__ = ["Run", "RunStatus"]
|