Files
deer-flow/backend/tests/test_actor.py
T
greatmengqi 3e17417122 feat: asyncio-native Actor framework with supervision, middleware, and pluggable mailbox
Lightweight actor library built on asyncio primitives (~800 lines):

- Actor base class with lifecycle hooks (on_started/on_stopped/on_restart)
- ActorRef with tell (fire-and-forget) and ask (request-response)
- Supervision: OneForOne/AllForOne strategies with restart limits
- Middleware pipeline for cross-cutting concerns
- Pluggable Mailbox interface (MemoryMailbox default, RedisMailbox optional)
- ReplyRegistry + ReplyChannel: ask() works across any mailbox backend
- System-level thread pool for blocking I/O (run_in_executor)
- Dead letter handling, poison message quarantine, parallel shutdown
- 22 tests + benchmark suite
2026-03-30 23:50:54 +08:00

443 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for the async Actor framework."""
import asyncio
import pytest
from deerflow.actor import (
Actor,
ActorRef,
ActorSystem,
AllForOneStrategy,
Directive,
Middleware,
OneForOneStrategy,
)
from deerflow.actor.ref import ActorStoppedError
# ---------------------------------------------------------------------------
# Basic actors for testing
# ---------------------------------------------------------------------------
class EchoActor(Actor):
async def on_receive(self, message):
return message
class CounterActor(Actor):
async def on_started(self):
self.count = 0
async def on_receive(self, message):
if message == "inc":
self.count += 1
elif message == "get":
return self.count
class CrashActor(Actor):
async def on_receive(self, message):
if message == "crash":
raise ValueError("boom")
return "ok"
class ParentActor(Actor):
def __init__(self):
self.child_ref: ActorRef | None = None
self.restarts = 0
def supervisor_strategy(self):
return OneForOneStrategy(max_restarts=3, within_seconds=60)
async def on_started(self):
self.child_ref = await self.context.spawn(CrashActor, "child")
async def on_receive(self, message):
if message == "get_child":
return self.child_ref
class StopOnCrashParent(Actor):
def supervisor_strategy(self):
return OneForOneStrategy(decider=lambda _: Directive.stop)
async def on_started(self):
self.child_ref = await self.context.spawn(CrashActor, "child")
async def on_receive(self, message):
if message == "get_child":
return self.child_ref
class AllForOneParent(Actor):
def supervisor_strategy(self):
return AllForOneStrategy(max_restarts=2, within_seconds=60)
async def on_started(self):
self.c1 = await self.context.spawn(CounterActor, "c1")
self.c2 = await self.context.spawn(CrashActor, "c2")
async def on_receive(self, message):
if message == "get_children":
return (self.c1, self.c2)
class LifecycleActor(Actor):
started = False
stopped = False
restarted_with: Exception | None = None
async def on_started(self):
LifecycleActor.started = True
async def on_stopped(self):
LifecycleActor.stopped = True
async def on_restart(self, error):
LifecycleActor.restarted_with = error
async def on_receive(self, message):
if message == "crash":
raise RuntimeError("lifecycle crash")
return "alive"
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBasicMessaging:
@pytest.mark.anyio
async def test_tell_and_ask(self):
system = ActorSystem("test")
ref = await system.spawn(EchoActor, "echo")
result = await ref.ask("hello")
assert result == "hello"
await system.shutdown()
@pytest.mark.anyio
async def test_ask_timeout(self):
class SlowActor(Actor):
async def on_receive(self, message):
await asyncio.sleep(10)
system = ActorSystem("test")
ref = await system.spawn(SlowActor, "slow")
with pytest.raises(asyncio.TimeoutError):
await ref.ask("hi", timeout=0.1)
await system.shutdown()
@pytest.mark.anyio
async def test_tell_fire_and_forget(self):
system = ActorSystem("test")
ref = await system.spawn(CounterActor, "counter")
await ref.tell("inc")
await ref.tell("inc")
await ref.tell("inc")
# Give the actor time to process
await asyncio.sleep(0.05)
count = await ref.ask("get")
assert count == 3
await system.shutdown()
@pytest.mark.anyio
async def test_ask_stopped_actor(self):
system = ActorSystem("test")
ref = await system.spawn(EchoActor, "echo")
ref.stop()
await asyncio.sleep(0.05)
with pytest.raises(ActorStoppedError):
await ref.ask("hello")
await system.shutdown()
@pytest.mark.anyio
async def test_tell_stopped_actor_goes_to_dead_letters(self):
system = ActorSystem("test")
ref = await system.spawn(EchoActor, "echo")
ref.stop()
await asyncio.sleep(0.05)
await ref.tell("orphan")
assert len(system.dead_letters) >= 1
await system.shutdown()
class TestActorPath:
@pytest.mark.anyio
async def test_root_actor_path(self):
system = ActorSystem("app")
ref = await system.spawn(EchoActor, "echo")
assert ref.path == "/app/echo"
await system.shutdown()
@pytest.mark.anyio
async def test_child_actor_path(self):
system = ActorSystem("app")
parent = await system.spawn(ParentActor, "parent")
child: ActorRef = await parent.ask("get_child")
assert child.path == "/app/parent/child"
await system.shutdown()
class TestLifecycle:
@pytest.mark.anyio
async def test_on_started_called(self):
LifecycleActor.started = False
system = ActorSystem("test")
await system.spawn(LifecycleActor, "lc")
assert LifecycleActor.started is True
await system.shutdown()
@pytest.mark.anyio
async def test_on_stopped_called(self):
LifecycleActor.stopped = False
system = ActorSystem("test")
ref = await system.spawn(LifecycleActor, "lc")
ref.stop()
await asyncio.sleep(0.1)
assert LifecycleActor.stopped is True
await system.shutdown()
@pytest.mark.anyio
async def test_shutdown_stops_all(self):
system = ActorSystem("test")
r1 = await system.spawn(EchoActor, "a")
r2 = await system.spawn(EchoActor, "b")
await system.shutdown()
assert not r1.is_alive
assert not r2.is_alive
class TestSupervision:
@pytest.mark.anyio
async def test_restart_on_crash(self):
system = ActorSystem("test")
parent = await system.spawn(ParentActor, "parent")
child: ActorRef = await parent.ask("get_child")
# Crash the child
with pytest.raises(ValueError, match="boom"):
await child.ask("crash")
await asyncio.sleep(0.1)
# Child should still be alive (restarted)
assert child.is_alive
result = await child.ask("safe")
assert result == "ok"
await system.shutdown()
@pytest.mark.anyio
async def test_stop_directive(self):
system = ActorSystem("test")
parent = await system.spawn(StopOnCrashParent, "parent")
child: ActorRef = await parent.ask("get_child")
with pytest.raises(ValueError, match="boom"):
await child.ask("crash")
await asyncio.sleep(0.1)
assert not child.is_alive
await system.shutdown()
@pytest.mark.anyio
async def test_restart_limit_exceeded(self):
system = ActorSystem("test")
class StrictParent(Actor):
def supervisor_strategy(self):
return OneForOneStrategy(max_restarts=2, within_seconds=60)
async def on_started(self):
self.child_ref = await self.context.spawn(CrashActor, "child")
async def on_receive(self, message):
return self.child_ref
parent = await system.spawn(StrictParent, "parent")
child: ActorRef = await parent.ask("any")
# Exhaust restart limit
for _ in range(3):
try:
await child.ask("crash")
except (ValueError, ActorStoppedError):
pass
await asyncio.sleep(0.05)
# After exceeding limit, child should be stopped
assert not child.is_alive
await system.shutdown()
@pytest.mark.anyio
async def test_all_for_one_restarts_siblings(self):
system = ActorSystem("test")
parent = await system.spawn(AllForOneParent, "parent")
c1, c2 = await parent.ask("get_children")
# Increment counter on c1
await c1.tell("inc")
await asyncio.sleep(0.05)
count_before = await c1.ask("get")
assert count_before == 1
# Crash c2 → AllForOne should restart both
try:
await c2.ask("crash")
except ValueError:
pass
await asyncio.sleep(0.1)
# c1 was restarted, counter should be 0
count_after = await c1.ask("get")
assert count_after == 0
await system.shutdown()
class TestDeadLetters:
@pytest.mark.anyio
async def test_dead_letter_callback(self):
received = []
system = ActorSystem("test")
system.on_dead_letter(lambda dl: received.append(dl))
ref = await system.spawn(EchoActor, "echo")
ref.stop()
await asyncio.sleep(0.05)
await ref.tell("orphan")
assert len(received) >= 1
assert received[-1].message == "orphan"
await system.shutdown()
class TestDuplicateNames:
@pytest.mark.anyio
async def test_duplicate_root_name_raises(self):
system = ActorSystem("test")
await system.spawn(EchoActor, "echo")
with pytest.raises(ValueError, match="already exists"):
await system.spawn(EchoActor, "echo")
await system.shutdown()
# ---------------------------------------------------------------------------
# Middleware tests
# ---------------------------------------------------------------------------
class LogMiddleware(Middleware):
def __init__(self):
self.log: list[str] = []
async def on_receive(self, ctx, message, next_fn):
self.log.append(f"before:{message}")
result = await next_fn(ctx, message)
self.log.append(f"after:{result}")
return result
async def on_started(self, actor_ref):
self.log.append("started")
async def on_stopped(self, actor_ref):
self.log.append("stopped")
class TransformMiddleware(Middleware):
"""Uppercases string messages before passing to actor."""
async def on_receive(self, ctx, message, next_fn):
if isinstance(message, str):
message = message.upper()
return await next_fn(ctx, message)
class TestExecutor:
@pytest.mark.anyio
async def test_run_in_executor(self):
"""Blocking function runs in thread pool without blocking event loop."""
import time
class BlockingActor(Actor):
async def on_receive(self, message):
# Simulate blocking I/O via thread pool
result = await self.context.run_in_executor(time.sleep, 0.01)
return "done"
system = ActorSystem("test", executor_workers=2)
ref = await system.spawn(BlockingActor, "blocker")
result = await ref.ask("go", timeout=5.0)
assert result == "done"
await system.shutdown()
@pytest.mark.anyio
async def test_concurrent_blocking_calls(self):
"""Multiple actors can run blocking I/O concurrently via shared pool."""
import time
class SlowActor(Actor):
async def on_receive(self, message):
await self.context.run_in_executor(time.sleep, 0.1)
return "ok"
system = ActorSystem("test", executor_workers=4)
refs = [await system.spawn(SlowActor, f"s{i}") for i in range(4)]
start = time.monotonic()
results = await asyncio.gather(*[r.ask("go", timeout=5.0) for r in refs])
elapsed = time.monotonic() - start
assert all(r == "ok" for r in results)
# 4 parallel × 0.1s should finish in ~0.1-0.2s, not 0.4s
assert elapsed < 0.3
await system.shutdown()
class TestMiddleware:
@pytest.mark.anyio
async def test_middleware_intercepts_messages(self):
mw = LogMiddleware()
system = ActorSystem("test")
ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
result = await ref.ask("hello")
assert result == "hello"
assert "before:hello" in mw.log
assert "after:hello" in mw.log
await system.shutdown()
@pytest.mark.anyio
async def test_middleware_lifecycle_hooks(self):
mw = LogMiddleware()
system = ActorSystem("test")
ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
assert "started" in mw.log
ref.stop()
await asyncio.sleep(0.1)
assert "stopped" in mw.log
await system.shutdown()
@pytest.mark.anyio
async def test_middleware_chain_order(self):
"""First middleware wraps outermost — sees original message."""
mw1 = LogMiddleware()
mw2 = TransformMiddleware()
system = ActorSystem("test")
# Chain: mw1(mw2(actor)). mw1 logs original, mw2 uppercases, actor echoes
ref = await system.spawn(EchoActor, "echo", middlewares=[mw1, mw2])
result = await ref.ask("hello")
assert result == "HELLO" # TransformMiddleware uppercased
assert "before:hello" in mw1.log # LogMiddleware saw original
assert "after:HELLO" in mw1.log # LogMiddleware saw transformed result
await system.shutdown()
@pytest.mark.anyio
async def test_middleware_with_tell(self):
mw = LogMiddleware()
system = ActorSystem("test")
await system.spawn(CounterActor, "counter", middlewares=[mw])
# tell goes through middleware too
assert any("before:" in entry for entry in mw.log) is False
await system.shutdown()