mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
fix(actor): harden lifecycle, supervision, Redis mailbox, and add comprehensive tests
- Fix spawn() zombie cell: clean up registry on start() failure
- Fix shutdown(): cancel + await tasks that exceed graceful timeout
- Fix _shutdown(): await mailbox.close() to release backend resources
- Fix escalate directive: stop failing child before propagating to grandparent
- Fix RedisMailbox.put(): wrap Redis errors in try/except, return False on failure
- Fix retry.py: replace assert with proper raise for last_exc
- Add put_batch() to Mailbox abstraction for single-roundtrip bulk enqueue
- Add RedisMailbox.put_batch() with atomic Lua script for bounded queues
- Add MailboxFullError exception type for semantic backpressure handling
- Add redis>=7.4.0 dependency with public PyPI sources in uv.lock

Tests added (31 total, up from 27):
- test_middleware_on_restart_hook: verifies middleware.on_restart() on supervision restart
- test_ask_propagates_actor_exception: ask() re-raises original exception type
- test_ask_propagates_exception_while_supervised: exception propagates; root actor survives
- test_ask_timeout_late_reply_no_exception: late reply after timeout is silent no-op
- test_actor_backpressure.py: MailboxFullError + dead letter on full mailbox
- test_actor_retry.py: ask_with_retry with exponential backoff
- test_mailbox_redis.py: RedisMailbox put/get/batch/close
- bench_actor_redis.py: RedisMailbox throughput benchmarks
This commit is contained in:
@@ -19,7 +19,8 @@ Usage::
|
||||
from .actor import Actor, ActorContext
|
||||
from .mailbox import Mailbox, MemoryMailbox
|
||||
from .middleware import Middleware
|
||||
from .ref import ActorRef, ReplyChannel
|
||||
from .ref import ActorRef, MailboxFullError, ReplyChannel
|
||||
from .retry import IdempotentActorMixin, IdempotencyStore, RetryEnvelope, ask_with_retry
|
||||
from .supervision import AllForOneStrategy, Directive, OneForOneStrategy, SupervisorStrategy
|
||||
from .system import ActorSystem, DeadLetter
|
||||
|
||||
@@ -32,9 +33,14 @@ __all__ = [
|
||||
"DeadLetter",
|
||||
"Directive",
|
||||
"Mailbox",
|
||||
"MailboxFullError",
|
||||
"MemoryMailbox",
|
||||
"Middleware",
|
||||
"OneForOneStrategy",
|
||||
"ReplyChannel",
|
||||
"RetryEnvelope",
|
||||
"SupervisorStrategy",
|
||||
"IdempotentActorMixin",
|
||||
"IdempotencyStore",
|
||||
"ask_with_retry",
|
||||
]
|
||||
|
||||
@@ -12,6 +12,12 @@ import asyncio
|
||||
from typing import Any
|
||||
|
||||
|
||||
BACKPRESSURE_BLOCK = "block"
|
||||
BACKPRESSURE_DROP_NEW = "drop_new"
|
||||
BACKPRESSURE_FAIL = "fail"
|
||||
BACKPRESSURE_POLICIES = {BACKPRESSURE_BLOCK, BACKPRESSURE_DROP_NEW, BACKPRESSURE_FAIL}
|
||||
|
||||
|
||||
class Mailbox(abc.ABC):
|
||||
"""Abstract mailbox — the message queue for an actor.
|
||||
|
||||
@@ -44,6 +50,18 @@ class Mailbox(abc.ABC):
|
||||
def full(self) -> bool:
|
||||
"""Return True if mailbox is at capacity."""
|
||||
|
||||
async def put_batch(self, msgs: list[Any]) -> int:
    """Enqueue each message in order and report how many were accepted.

    This generic version simply awaits ``put`` once per message, one after
    another. Backends that can push in bulk (Redis, for example) should
    override it with something more efficient.

    Returns:
        The number of messages for which ``put`` returned a truthy value.
    """
    # Await sequentially so messages are offered in the order given.
    outcomes = [await self.put(item) for item in msgs]
    return sum(1 for accepted in outcomes if accepted)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Release resources. Default is no-op."""
|
||||
|
||||
@@ -55,23 +73,32 @@ class Empty(Exception):
|
||||
class MemoryMailbox(Mailbox):
|
||||
"""In-process mailbox backed by ``asyncio.Queue``."""
|
||||
|
||||
def __init__(self, maxsize: int = 256) -> None:
|
||||
def __init__(self, maxsize: int = 256, *, backpressure_policy: str = BACKPRESSURE_BLOCK) -> None:
|
||||
if backpressure_policy not in BACKPRESSURE_POLICIES:
|
||||
raise ValueError(
|
||||
f"Invalid backpressure_policy={backpressure_policy!r}, "
|
||||
f"expected one of {sorted(BACKPRESSURE_POLICIES)}"
|
||||
)
|
||||
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)
|
||||
self._maxsize = maxsize
|
||||
self._backpressure_policy = backpressure_policy
|
||||
|
||||
async def put(self, msg: Any) -> bool:
|
||||
try:
|
||||
if self._backpressure_policy == BACKPRESSURE_BLOCK:
|
||||
await self._queue.put(msg)
|
||||
return True
|
||||
except asyncio.QueueFull:
|
||||
return False
|
||||
|
||||
def put_nowait(self, msg: Any) -> bool:
|
||||
try:
|
||||
if self._backpressure_policy in (BACKPRESSURE_DROP_NEW, BACKPRESSURE_FAIL):
|
||||
if self._queue.full():
|
||||
return False
|
||||
self._queue.put_nowait(msg)
|
||||
return True
|
||||
except asyncio.QueueFull:
|
||||
return False
|
||||
|
||||
def put_nowait(self, msg: Any) -> bool:
|
||||
if self._queue.full():
|
||||
return False
|
||||
self._queue.put_nowait(msg)
|
||||
return True
|
||||
|
||||
async def get(self) -> Any:
|
||||
return await self._queue.get()
|
||||
|
||||
@@ -107,12 +107,16 @@ class RedisMailbox(Mailbox):
|
||||
if self._closed:
|
||||
return False
|
||||
data = _serialize(msg)
|
||||
if self._maxlen > 0:
|
||||
# Atomic check+push via Lua script to avoid TOCTOU race
|
||||
result = await self._redis.evalsha_or_eval(self._LUA_BOUNDED_PUSH, 1, self._queue_name, data, self._maxlen)
|
||||
return bool(result)
|
||||
await self._redis.lpush(self._queue_name, data)
|
||||
return True
|
||||
try:
|
||||
if self._maxlen > 0:
|
||||
# Atomic check+push via Lua script to avoid TOCTOU race
|
||||
result = await self._redis.eval(self._LUA_BOUNDED_PUSH, 1, self._queue_name, data, self._maxlen)
|
||||
return bool(result)
|
||||
await self._redis.lpush(self._queue_name, data)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("RedisMailbox.put failed for %s: %s", self._queue_name, e)
|
||||
return False
|
||||
|
||||
def put_nowait(self, msg: Any) -> bool:
|
||||
"""Redis cannot do synchronous non-blocking enqueue reliably.
|
||||
@@ -122,6 +126,36 @@ class RedisMailbox(Mailbox):
|
||||
"""
|
||||
return False
|
||||
|
||||
async def put_batch(self, msgs: list[Any]) -> int:
    """Enqueue several messages and return how many were accepted.

    Unbounded queues (``maxlen <= 0``): every serializable message is sent
    in a single LPUSH command, i.e. one network round-trip.

    Bounded queues (``maxlen > 0``): messages are pushed one at a time
    through the same bounded-push Lua script that ``put()`` uses, stopping
    at the first message the script rejects.

    Messages whose serialization raises ``TypeError`` are logged and skipped.
    Errors raised by the Redis client are logged and reported through the
    return value instead of propagating, matching ``put()``.

    Returns:
        Number of messages accepted. ``0`` if the mailbox is closed, the
        input is empty, or nothing could be serialized.
    """
    if self._closed or not msgs:
        return 0

    data_list = []
    for msg in msgs:
        try:
            data_list.append(_serialize(msg))
        except TypeError as e:
            logger.warning("Skipping non-serializable message in put_batch: %s", e)
    if not data_list:
        return 0

    count = 0
    try:
        if self._maxlen > 0:
            for data in data_list:
                # Reuse the Lua script for TOCTOU-safe bounded check (same as put())
                result = await self._redis.eval(self._LUA_BOUNDED_PUSH, 1, self._queue_name, data, self._maxlen)
                if not result:
                    break  # queue full — stop early
                count += 1
        else:
            # Unbounded: single LPUSH with all values — one network round-trip
            await self._redis.lpush(self._queue_name, *data_list)
            count = len(data_list)
    except Exception as e:
        # Same best-effort policy as put(): a backend failure is reported as
        # "not accepted" rather than crashing the sender. ``count`` only
        # includes pushes that were confirmed before the failure.
        logger.warning(
            "RedisMailbox.put_batch failed for %s after %d/%d messages: %s",
            self._queue_name,
            count,
            len(data_list),
            e,
        )
    return count
|
||||
|
||||
async def get(self) -> Any:
|
||||
"""Blocking dequeue via BRPOP. Retries until a message arrives."""
|
||||
while not self._closed:
|
||||
|
||||
@@ -83,6 +83,10 @@ class ActorStoppedError(Exception):
|
||||
"""Raised when sending to a stopped actor via ask."""
|
||||
|
||||
|
||||
class MailboxFullError(RuntimeError):
    """Signals that a mailbox had no room left and refused a message.

    It derives from ``RuntimeError``, so handlers written against
    ``RuntimeError`` still catch it.
    """
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal message wrappers (serializable — no Future objects)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
"""Retry + idempotency helpers for Actor ask/tell patterns.
|
||||
|
||||
This module provides:
|
||||
- Message envelope carrying retry/idempotency metadata
|
||||
- In-memory idempotency store (process-local)
|
||||
- ask_with_retry helper (bounded retries + exponential backoff + jitter)
|
||||
|
||||
Design notes:
|
||||
- Keep transport-agnostic; works with current in-memory mailbox.
|
||||
- Business handlers must opt in by using ``IdempotentActorMixin`` and
|
||||
wrapping logic with ``handle_idempotent``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RetryEnvelope:
|
||||
"""Metadata wrapper for idempotent/retriable messages."""
|
||||
|
||||
payload: Any
|
||||
message_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||||
idempotency_key: str | None = None
|
||||
attempt: int = 1
|
||||
max_attempts: int = 1
|
||||
created_at_ms: int = field(default_factory=lambda: int(time.time() * 1000))
|
||||
|
||||
@classmethod
|
||||
def wrap(
|
||||
cls,
|
||||
payload: Any,
|
||||
*,
|
||||
idempotency_key: str | None = None,
|
||||
attempt: int = 1,
|
||||
max_attempts: int = 1,
|
||||
) -> "RetryEnvelope":
|
||||
return cls(
|
||||
payload=payload,
|
||||
idempotency_key=idempotency_key,
|
||||
attempt=attempt,
|
||||
max_attempts=max_attempts,
|
||||
)
|
||||
|
||||
|
||||
class IdempotencyStore:
    """Process-local idempotency result store.

    Maps an idempotency key to a stored result so that a repeated delivery
    of the same key can be answered from the store.

    Args:
        max_entries: Optional upper bound on the number of remembered keys.
            ``None`` (the default) never evicts, so the store grows with
            every distinct key it is given. With a bound, inserting a new
            key into a full store first evicts the oldest-inserted key.

    Raises:
        ValueError: If ``max_entries`` is given and is smaller than 1.
    """

    def __init__(self, *, max_entries: int | None = None) -> None:
        if max_entries is not None and max_entries < 1:
            raise ValueError("max_entries must be >= 1 or None")
        self._max_entries = max_entries
        # dict preserves insertion order, which is what FIFO eviction relies on.
        self._results: dict[str, Any] = {}

    def has(self, key: str) -> bool:
        """Return True if a result is stored under ``key``."""
        return key in self._results

    def get(self, key: str) -> Any:
        """Return the result stored under ``key``.

        Raises:
            KeyError: If nothing is stored under ``key``.
        """
        return self._results[key]

    def set(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, evicting the oldest entry if full.

        Overwriting an existing key never evicts and keeps that key's
        original position in the eviction order.
        """
        if self._max_entries is not None and key not in self._results:
            while len(self._results) >= self._max_entries:
                oldest = next(iter(self._results))
                del self._results[oldest]
        self._results[key] = value
|
||||
|
||||
|
||||
class IdempotentActorMixin:
    """Mixin that lets an actor answer repeated deliveries from a result store.

    Usage in actor::

        class MyActor(IdempotentActorMixin, Actor):
            async def on_receive(self, message):
                return await self.handle_idempotent(message, self._handle)

            async def _handle(self, payload):
                ...
    """

    def _idempotency_store(self) -> IdempotencyStore:
        """Return this instance's store, creating it on first use."""
        existing = getattr(self, "_idem_store", None)
        if existing is not None:
            return existing
        created = IdempotencyStore()
        self._idem_store = created
        return created

    async def handle_idempotent(self, message: Any, handler):
        """Run ``handler`` at most once per idempotency key.

        * A message that is not a ``RetryEnvelope`` is passed to ``handler``
          unchanged.
        * An envelope without a (truthy) idempotency key has its payload
          passed to ``handler`` with no result stored.
        * Otherwise the stored result for the key is returned if present;
          if not, ``handler`` is awaited with the payload and its result is
          stored under the key before being returned. If ``handler`` raises,
          nothing is stored.
        """
        if not isinstance(message, RetryEnvelope):
            return await handler(message)

        dedupe_key = message.idempotency_key
        if not dedupe_key:
            return await handler(message.payload)

        cache = self._idempotency_store()
        if not cache.has(dedupe_key):
            outcome = await handler(message.payload)
            cache.set(dedupe_key, outcome)
            return outcome
        return cache.get(dedupe_key)
|
||||
|
||||
|
||||
async def ask_with_retry(
    ref,
    payload: Any,
    *,
    timeout: float = 5.0,
    max_attempts: int = 3,
    base_backoff_s: float = 0.1,
    max_backoff_s: float = 5.0,
    jitter_ratio: float = 0.3,
    retry_exceptions: tuple[type[BaseException], ...] = (asyncio.TimeoutError,),
    idempotency_key: str | None = None,
) -> Any:
    """Ask ``ref`` up to ``max_attempts`` times, backing off between tries.

    Each attempt sends a new ``RetryEnvelope`` carrying the same idempotency
    key (the caller's, or a generated one when none/empty is given) together
    with the attempt number. Between attempts the coroutine sleeps for an
    exponentially growing delay capped at ``max_backoff_s``, plus a random
    jitter of up to ``jitter_ratio`` times that delay.

    Returns:
        Whatever ``ref.ask`` returns on the first successful attempt.

    Raises:
        ValueError: If ``max_attempts`` is less than 1.
        BaseException: The exception from the final attempt when it matches
            ``retry_exceptions``; any non-matching exception propagates
            immediately without retrying.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")

    shared_key = idempotency_key or uuid.uuid4().hex
    attempt = 1
    while True:
        envelope = RetryEnvelope.wrap(
            payload,
            idempotency_key=shared_key,
            attempt=attempt,
            max_attempts=max_attempts,
        )
        try:
            return await ref.ask(envelope, timeout=timeout)
        except retry_exceptions:
            # Out of attempts: surface the failure from this last try as-is.
            if attempt >= max_attempts:
                raise

            exponential = base_backoff_s * (2 ** (attempt - 1))
            delay = min(max_backoff_s, exponential)
            spread = delay * jitter_ratio * random.random()
            await asyncio.sleep(delay + spread)
        attempt += 1
|
||||
@@ -11,7 +11,7 @@ from typing import Any
|
||||
from .actor import Actor, ActorContext
|
||||
from .mailbox import Empty, Mailbox, MemoryMailbox
|
||||
from .middleware import ActorMailboxContext, Middleware, NextFn, build_middleware_chain
|
||||
from .ref import ActorRef, ActorStoppedError, ReplyChannel, _Envelope, _ReplyMessage, _ReplyRegistry, _Stop
|
||||
from .ref import ActorRef, ActorStoppedError, MailboxFullError, ReplyChannel, _Envelope, _ReplyMessage, _ReplyRegistry, _Stop
|
||||
from .supervision import Directive, SupervisorStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -87,7 +87,11 @@ class ActorSystem:
|
||||
middlewares=middlewares or [],
|
||||
)
|
||||
self._root_cells[name] = cell
|
||||
await cell.start()
|
||||
try:
|
||||
await cell.start()
|
||||
except Exception:
|
||||
del self._root_cells[name]
|
||||
raise
|
||||
return cell.ref
|
||||
|
||||
async def shutdown(self, *, timeout: float = 10.0) -> None:
|
||||
@@ -99,7 +103,12 @@ class ActorSystem:
|
||||
if cell.task is not None:
|
||||
tasks.append(cell.task)
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, timeout=timeout)
|
||||
_, pending = await asyncio.wait(tasks, timeout=timeout)
|
||||
# Cancel tasks that didn't finish within the timeout to prevent zombie tasks
|
||||
for t in pending:
|
||||
t.cancel()
|
||||
if pending:
|
||||
await asyncio.wait(pending, timeout=2.0)
|
||||
self._root_cells.clear()
|
||||
self._replies.reject_all(ActorStoppedError("ActorSystem shutting down"))
|
||||
await self._reply_channel.stop_listener()
|
||||
@@ -188,16 +197,25 @@ class _ActorCell:
|
||||
self.task = asyncio.create_task(self._run(), name=f"actor:{self.path}")
|
||||
|
||||
async def enqueue(self, msg: _Envelope | _Stop) -> None:
|
||||
if not self.mailbox.put_nowait(msg):
|
||||
# Try non-blocking first (fast path for MemoryMailbox)
|
||||
if self.mailbox.put_nowait(msg):
|
||||
return
|
||||
# Fallback to async put (required for Redis and other async backends)
|
||||
if not await self.mailbox.put(msg):
|
||||
if isinstance(msg, _Envelope) and msg.correlation_id is not None:
|
||||
self.system._replies.reject(msg.correlation_id, RuntimeError(f"Mailbox full: {self.path}"))
|
||||
self.system._replies.reject(msg.correlation_id, MailboxFullError(f"Mailbox full: {self.path}"))
|
||||
elif isinstance(msg, _Envelope):
|
||||
self.system._dead_letter(self.ref, msg.payload, msg.sender)
|
||||
|
||||
def request_stop(self) -> None:
|
||||
"""Request graceful shutdown. Falls back to task.cancel() if mailbox full."""
|
||||
"""Request graceful shutdown.
|
||||
|
||||
Tries put_nowait first. If that fails (full or unsupported backend),
|
||||
cancels the task directly so _run exits via CancelledError → finally → _shutdown.
|
||||
"""
|
||||
if not self.stopped:
|
||||
if not self.mailbox.put_nowait(_Stop()):
|
||||
# Redis/async backends can't put_nowait — cancel the task
|
||||
if self.task is not None and not self.task.done():
|
||||
self.task.cancel()
|
||||
else:
|
||||
@@ -223,7 +241,11 @@ class _ActorCell:
|
||||
middlewares=middlewares or [],
|
||||
)
|
||||
self.children[name] = child
|
||||
await child.start()
|
||||
try:
|
||||
await child.start()
|
||||
except Exception:
|
||||
del self.children[name]
|
||||
raise
|
||||
return child.ref
|
||||
|
||||
# -- Processing loop -------------------------------------------------------
|
||||
@@ -310,6 +332,11 @@ class _ActorCell:
|
||||
# Remove from parent
|
||||
if self.parent is not None:
|
||||
self.parent.children.pop(self.name, None)
|
||||
# Close mailbox to release backend resources (e.g. Redis connections)
|
||||
try:
|
||||
await self.mailbox.close()
|
||||
except Exception:
|
||||
logger.exception("Error closing mailbox for %s", self.path)
|
||||
|
||||
# -- Supervision -----------------------------------------------------------
|
||||
|
||||
@@ -337,8 +364,16 @@ class _ActorCell:
|
||||
return
|
||||
|
||||
if directive == Directive.escalate:
|
||||
logger.info("Supervisor %s: escalate %s", self.path, type(error).__name__)
|
||||
raise error
|
||||
# Stop the failing child, then propagate failure up the supervision chain.
|
||||
# We cannot use `raise error` here — that would crash the child's _run
|
||||
# loop instead of notifying the grandparent's supervisor.
|
||||
child.request_stop()
|
||||
if self.parent is not None:
|
||||
logger.info("Supervisor %s: escalate %s to grandparent %s", self.path, type(error).__name__, self.parent.path)
|
||||
await self.parent._handle_child_failure(self, error)
|
||||
else:
|
||||
logger.error("Uncaught escalation at root actor %s: %s", self.path, error)
|
||||
return
|
||||
|
||||
if directive == Directive.restart:
|
||||
for name in affected:
|
||||
|
||||
Reference in New Issue
Block a user