mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 15:11:09 +00:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4383d96583 | |||
| 3e461d9d08 | |||
| cf43584d24 | |||
| 6ff60f2af1 | |||
| a3bfea631c | |||
| aae59a8ba8 | |||
| 3ff15423d6 | |||
| c2f7be37b3 | |||
| 09a9209724 | |||
| b356a13da5 | |||
| ac9a6ee6a2 | |||
| 64e0f5329a |
@@ -2,8 +2,6 @@
|
|||||||
docker/.cache/
|
docker/.cache/
|
||||||
# oh-my-claudecode state
|
# oh-my-claudecode state
|
||||||
.omc/
|
.omc/
|
||||||
# Collaborator plugin state
|
|
||||||
.collaborator/
|
|
||||||
# OS generated files
|
# OS generated files
|
||||||
.DS_Store
|
.DS_Store
|
||||||
*.local
|
*.local
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class RunCreateRequest(BaseModel):
|
|||||||
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
|
||||||
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
|
||||||
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
|
||||||
|
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
|
||||||
webhook: str | None = Field(default=None, description="Completion callback URL")
|
webhook: str | None = Field(default=None, description="Completion callback URL")
|
||||||
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
|
||||||
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -93,20 +94,56 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
|
|||||||
return raw_input
|
return raw_input
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_ASSISTANT_ID = "lead_agent"
|
||||||
|
|
||||||
|
|
||||||
def resolve_agent_factory(assistant_id: str | None):
|
def resolve_agent_factory(assistant_id: str | None):
|
||||||
"""Resolve the agent factory callable from config."""
|
"""Resolve the agent factory callable from config.
|
||||||
|
|
||||||
|
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
|
||||||
|
injected into ``configurable`` — see :func:`build_run_config`. All
|
||||||
|
``assistant_id`` values therefore map to the same factory; the routing
|
||||||
|
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
|
||||||
|
"""
|
||||||
from deerflow.agents.lead_agent.agent import make_lead_agent
|
from deerflow.agents.lead_agent.agent import make_lead_agent
|
||||||
|
|
||||||
if assistant_id and assistant_id != "lead_agent":
|
|
||||||
logger.info("assistant_id=%s requested; falling back to lead_agent", assistant_id)
|
|
||||||
return make_lead_agent
|
return make_lead_agent
|
||||||
|
|
||||||
|
|
||||||
def build_run_config(thread_id: str, request_config: dict[str, Any] | None, metadata: dict[str, Any] | None) -> dict[str, Any]:
|
def build_run_config(
|
||||||
"""Build a RunnableConfig dict for the agent."""
|
thread_id: str,
|
||||||
configurable = {"thread_id": thread_id}
|
request_config: dict[str, Any] | None,
|
||||||
|
metadata: dict[str, Any] | None,
|
||||||
|
*,
|
||||||
|
assistant_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build a RunnableConfig dict for the agent.
|
||||||
|
|
||||||
|
When *assistant_id* refers to a custom agent (anything other than
|
||||||
|
``"lead_agent"`` / ``None``), the name is forwarded as
|
||||||
|
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
|
||||||
|
load the matching ``agents/<name>/SOUL.md`` and per-agent config —
|
||||||
|
without it the agent silently runs as the default lead agent.
|
||||||
|
|
||||||
|
This mirrors the channel manager's ``_resolve_run_params`` logic so that
|
||||||
|
the LangGraph Platform-compatible HTTP API and the IM channel path behave
|
||||||
|
identically.
|
||||||
|
"""
|
||||||
|
configurable: dict[str, Any] = {"thread_id": thread_id}
|
||||||
if request_config:
|
if request_config:
|
||||||
configurable.update(request_config.get("configurable", {}))
|
configurable.update(request_config.get("configurable", {}))
|
||||||
|
|
||||||
|
# Inject custom agent name when the caller specified a non-default assistant.
|
||||||
|
# Honour an explicit configurable["agent_name"] in the request if already set.
|
||||||
|
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "agent_name" not in configurable:
|
||||||
|
# Normalize the same way ChannelManager does: strip, lowercase,
|
||||||
|
# replace underscores with hyphens, then validate to prevent path
|
||||||
|
# traversal and invalid agent directory lookups.
|
||||||
|
normalized = assistant_id.strip().lower().replace("_", "-")
|
||||||
|
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
|
||||||
|
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
|
||||||
|
configurable["agent_name"] = normalized
|
||||||
|
|
||||||
config: dict[str, Any] = {"configurable": configurable, "recursion_limit": 100}
|
config: dict[str, Any] = {"configurable": configurable, "recursion_limit": 100}
|
||||||
if request_config:
|
if request_config:
|
||||||
for k, v in request_config.items():
|
for k, v in request_config.items():
|
||||||
@@ -233,7 +270,28 @@ async def start_run(
|
|||||||
|
|
||||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||||
graph_input = normalize_input(body.input)
|
graph_input = normalize_input(body.input)
|
||||||
config = build_run_config(thread_id, body.config, body.metadata)
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
|
|
||||||
|
# Merge DeerFlow-specific context overrides into configurable.
|
||||||
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
|
context = getattr(body, "context", None)
|
||||||
|
if context:
|
||||||
|
_CONTEXT_CONFIGURABLE_KEYS = {
|
||||||
|
"model_name",
|
||||||
|
"mode",
|
||||||
|
"thinking_enabled",
|
||||||
|
"reasoning_effort",
|
||||||
|
"is_plan_mode",
|
||||||
|
"subagent_enabled",
|
||||||
|
"max_concurrent_subagents",
|
||||||
|
}
|
||||||
|
configurable = config.setdefault("configurable", {})
|
||||||
|
for key in _CONTEXT_CONFIGURABLE_KEYS:
|
||||||
|
if key in context:
|
||||||
|
configurable.setdefault(key, context[key])
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
|
|||||||
@@ -257,6 +257,8 @@ sandbox:
|
|||||||
read_only: false
|
read_only: false
|
||||||
```
|
```
|
||||||
|
|
||||||
|
When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` values in the agent prompt so the agent can discover and operate on mounted directories directly instead of assuming everything must live under `/mnt/user-data`.
|
||||||
|
|
||||||
### Skills
|
### Skills
|
||||||
|
|
||||||
Configure the skills directory for specialized workflows:
|
Configure the skills directory for specialized workflows:
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Async Actor framework — lightweight, asyncio-native, supervision-ready.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
from deerflow.actor import Actor, ActorSystem
|
|
||||||
|
|
||||||
class Greeter(Actor):
|
|
||||||
async def on_receive(self, message):
|
|
||||||
return f"Hello, {message}!"
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
system = ActorSystem("app")
|
|
||||||
ref = await system.spawn(Greeter, "greeter")
|
|
||||||
reply = await ref.ask("World", timeout=5.0)
|
|
||||||
print(reply) # Hello, World!
|
|
||||||
await system.shutdown()
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .actor import Actor, ActorContext
|
|
||||||
from .mailbox import Mailbox, MemoryMailbox
|
|
||||||
from .middleware import Middleware
|
|
||||||
from .ref import ActorRef, MailboxFullError, ReplyChannel
|
|
||||||
from .retry import IdempotentActorMixin, IdempotencyStore, RetryEnvelope, ask_with_retry
|
|
||||||
from .supervision import AllForOneStrategy, Directive, OneForOneStrategy, SupervisorStrategy
|
|
||||||
from .system import ActorSystem, DeadLetter
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Actor",
|
|
||||||
"ActorContext",
|
|
||||||
"ActorRef",
|
|
||||||
"ActorSystem",
|
|
||||||
"AllForOneStrategy",
|
|
||||||
"DeadLetter",
|
|
||||||
"Directive",
|
|
||||||
"Mailbox",
|
|
||||||
"MailboxFullError",
|
|
||||||
"MemoryMailbox",
|
|
||||||
"Middleware",
|
|
||||||
"OneForOneStrategy",
|
|
||||||
"ReplyChannel",
|
|
||||||
"RetryEnvelope",
|
|
||||||
"SupervisorStrategy",
|
|
||||||
"IdempotentActorMixin",
|
|
||||||
"IdempotencyStore",
|
|
||||||
"ask_with_retry",
|
|
||||||
]
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
"""Actor base class and per-actor context."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
|
||||||
|
|
||||||
from .supervision import OneForOneStrategy, SupervisorStrategy
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .ref import ActorRef
|
|
||||||
|
|
||||||
# Message type variable — use Actor[MyMsg] for typed actors
|
|
||||||
M = TypeVar("M")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
class ActorContext:
|
|
||||||
"""Per-actor runtime context, injected before ``on_started``.
|
|
||||||
|
|
||||||
Provides access to the actor's identity, parent, children,
|
|
||||||
and the ability to spawn child actors.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_cell",)
|
|
||||||
|
|
||||||
def __init__(self, cell: Any) -> None:
|
|
||||||
self._cell = cell
|
|
||||||
|
|
||||||
@property
|
|
||||||
def self_ref(self) -> ActorRef:
|
|
||||||
return self._cell.ref
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parent(self) -> ActorRef | None:
|
|
||||||
p = self._cell.parent
|
|
||||||
return p.ref if p is not None else None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def children(self) -> dict[str, ActorRef]:
|
|
||||||
return {name: c.ref for name, c in self._cell.children.items()}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def system(self) -> Any:
|
|
||||||
return self._cell.system
|
|
||||||
|
|
||||||
async def spawn(
|
|
||||||
self,
|
|
||||||
actor_cls: type[Actor],
|
|
||||||
name: str,
|
|
||||||
*,
|
|
||||||
mailbox_size: int = 256,
|
|
||||||
middlewares: list | None = None,
|
|
||||||
) -> ActorRef:
|
|
||||||
"""Spawn a child actor supervised by this actor."""
|
|
||||||
return await self._cell.spawn_child(actor_cls, name, mailbox_size=mailbox_size, middlewares=middlewares)
|
|
||||||
|
|
||||||
async def run_in_executor(self, fn: Callable[..., Any], *args: Any) -> Any:
|
|
||||||
"""Run a blocking function in the system's thread pool.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
result = await self.context.run_in_executor(requests.get, url)
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
executor = self._cell.system._executor
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(executor, fn, *args)
|
|
||||||
|
|
||||||
|
|
||||||
class Actor(Generic[M]):
|
|
||||||
"""Base class for all actors.
|
|
||||||
|
|
||||||
Type parameter ``M`` constrains the message type::
|
|
||||||
|
|
||||||
class Greeter(Actor[str]):
|
|
||||||
async def on_receive(self, message: str) -> str:
|
|
||||||
return f"Hello, {message}!"
|
|
||||||
|
|
||||||
class Calculator(Actor[int | tuple[str, int, int]]):
|
|
||||||
async def on_receive(self, message: int | tuple[str, int, int]) -> int:
|
|
||||||
...
|
|
||||||
|
|
||||||
Unparameterized ``Actor`` accepts ``Any`` (backward-compatible).
|
|
||||||
"""
|
|
||||||
|
|
||||||
context: ActorContext
|
|
||||||
|
|
||||||
async def on_receive(self, message: M) -> Any:
|
|
||||||
"""Handle an incoming message.
|
|
||||||
|
|
||||||
Return value is sent back as reply for ``ask`` calls.
|
|
||||||
For ``tell`` calls, the return value is discarded.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def on_started(self) -> None:
|
|
||||||
"""Called after creation, before receiving messages."""
|
|
||||||
|
|
||||||
async def on_stopped(self) -> None:
|
|
||||||
"""Called on graceful shutdown. Release resources here."""
|
|
||||||
|
|
||||||
async def on_restart(self, error: Exception) -> None:
|
|
||||||
"""Called on the *new* instance before resuming after a crash."""
|
|
||||||
|
|
||||||
def supervisor_strategy(self) -> SupervisorStrategy:
|
|
||||||
"""Override to customize how this actor supervises its children.
|
|
||||||
|
|
||||||
Default: OneForOne, up to 3 restarts per 60 seconds, always restart.
|
|
||||||
"""
|
|
||||||
return OneForOneStrategy()
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
"""Pluggable mailbox abstraction — Akka-inspired enqueue/dequeue interface.
|
|
||||||
|
|
||||||
Built-in implementations:
|
|
||||||
- ``MemoryMailbox``: asyncio.Queue backed (default)
|
|
||||||
- Extend ``Mailbox`` for Redis, RabbitMQ, Kafka, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import asyncio
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
BACKPRESSURE_BLOCK = "block"
|
|
||||||
BACKPRESSURE_DROP_NEW = "drop_new"
|
|
||||||
BACKPRESSURE_FAIL = "fail"
|
|
||||||
BACKPRESSURE_POLICIES = {BACKPRESSURE_BLOCK, BACKPRESSURE_DROP_NEW, BACKPRESSURE_FAIL}
|
|
||||||
|
|
||||||
|
|
||||||
class Mailbox(abc.ABC):
|
|
||||||
"""Abstract mailbox — the message queue for an actor.
|
|
||||||
|
|
||||||
Implementations must be async-safe for single-consumer usage.
|
|
||||||
Multiple producers may call ``put`` concurrently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def put(self, msg: Any) -> bool:
|
|
||||||
"""Enqueue a message. Returns True if accepted, False if dropped."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def put_nowait(self, msg: Any) -> bool:
|
|
||||||
"""Non-blocking enqueue. Returns True if accepted, False if dropped."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def get(self) -> Any:
|
|
||||||
"""Dequeue the next message. Blocks until available."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def get_nowait(self) -> Any:
|
|
||||||
"""Non-blocking dequeue. Raises ``Empty`` if no message."""
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def empty(self) -> bool:
|
|
||||||
"""Return True if no messages are queued."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def full(self) -> bool:
|
|
||||||
"""Return True if mailbox is at capacity."""
|
|
||||||
|
|
||||||
async def put_batch(self, msgs: list[Any]) -> int:
|
|
||||||
"""Enqueue multiple messages. Returns count accepted.
|
|
||||||
|
|
||||||
Default implementation falls back to sequential ``put`` calls.
|
|
||||||
Backends like Redis should override this for efficient bulk push.
|
|
||||||
"""
|
|
||||||
count = 0
|
|
||||||
for msg in msgs:
|
|
||||||
if await self.put(msg):
|
|
||||||
count += 1
|
|
||||||
return count
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
"""Release resources. Default is no-op."""
|
|
||||||
|
|
||||||
|
|
||||||
class Empty(Exception):
|
|
||||||
"""Raised by ``get_nowait`` when mailbox is empty."""
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryMailbox(Mailbox):
|
|
||||||
"""In-process mailbox backed by ``asyncio.Queue``."""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 256, *, backpressure_policy: str = BACKPRESSURE_BLOCK) -> None:
|
|
||||||
if backpressure_policy not in BACKPRESSURE_POLICIES:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid backpressure_policy={backpressure_policy!r}, "
|
|
||||||
f"expected one of {sorted(BACKPRESSURE_POLICIES)}"
|
|
||||||
)
|
|
||||||
self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._backpressure_policy = backpressure_policy
|
|
||||||
|
|
||||||
async def put(self, msg: Any) -> bool:
|
|
||||||
if self._backpressure_policy == BACKPRESSURE_BLOCK:
|
|
||||||
await self._queue.put(msg)
|
|
||||||
return True
|
|
||||||
if self._backpressure_policy in (BACKPRESSURE_DROP_NEW, BACKPRESSURE_FAIL):
|
|
||||||
if self._queue.full():
|
|
||||||
return False
|
|
||||||
self._queue.put_nowait(msg)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def put_nowait(self, msg: Any) -> bool:
|
|
||||||
if self._queue.full():
|
|
||||||
return False
|
|
||||||
self._queue.put_nowait(msg)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def get(self) -> Any:
|
|
||||||
return await self._queue.get()
|
|
||||||
|
|
||||||
def get_nowait(self) -> Any:
|
|
||||||
try:
|
|
||||||
return self._queue.get_nowait()
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
raise Empty("mailbox empty")
|
|
||||||
|
|
||||||
def empty(self) -> bool:
|
|
||||||
return self._queue.empty()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def full(self) -> bool:
|
|
||||||
return self._queue.full()
|
|
||||||
|
|
||||||
|
|
||||||
# Type alias for mailbox factory
|
|
||||||
MailboxFactory = type[Mailbox] | Any # Callable[[], Mailbox]
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
"""Redis-backed mailbox — persistent, survives process restart.
|
|
||||||
|
|
||||||
Requires ``redis[hiredis]`` (``uv add redis[hiredis]``).
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
import redis.asyncio as redis
|
|
||||||
from deerflow.actor import ActorSystem
|
|
||||||
from deerflow.actor.mailbox_redis import RedisMailbox
|
|
||||||
|
|
||||||
pool = redis.ConnectionPool.from_url("redis://localhost:6379")
|
|
||||||
|
|
||||||
system = ActorSystem("app")
|
|
||||||
ref = await system.spawn(
|
|
||||||
MyActor, "worker",
|
|
||||||
mailbox=RedisMailbox(pool, "actor:inbox:worker"),
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .mailbox import Empty, Mailbox
|
|
||||||
from .ref import _Envelope, _Stop
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _serialize(msg: _Envelope | _Stop) -> str:
|
|
||||||
"""Serialize an envelope to JSON for Redis storage.
|
|
||||||
|
|
||||||
Raises ``TypeError`` if the payload is not JSON-serializable.
|
|
||||||
"""
|
|
||||||
if isinstance(msg, _Stop):
|
|
||||||
return json.dumps({"__type__": "stop"})
|
|
||||||
try:
|
|
||||||
return json.dumps({
|
|
||||||
"__type__": "envelope",
|
|
||||||
"payload": msg.payload,
|
|
||||||
"correlation_id": msg.correlation_id,
|
|
||||||
"reply_to": msg.reply_to,
|
|
||||||
})
|
|
||||||
except (TypeError, ValueError) as e:
|
|
||||||
raise TypeError(f"Payload is not JSON-serializable: {e}. RedisMailbox requires JSON-compatible messages.") from e
|
|
||||||
|
|
||||||
|
|
||||||
def _deserialize(data: str | bytes) -> _Envelope | _Stop:
|
|
||||||
"""Deserialize a JSON string back to an envelope or stop sentinel."""
|
|
||||||
if isinstance(data, bytes):
|
|
||||||
data = data.decode("utf-8")
|
|
||||||
d = json.loads(data)
|
|
||||||
if d.get("__type__") == "stop":
|
|
||||||
return _Stop()
|
|
||||||
return _Envelope(
|
|
||||||
payload=d.get("payload"),
|
|
||||||
sender=None,
|
|
||||||
correlation_id=d.get("correlation_id"),
|
|
||||||
reply_to=d.get("reply_to"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisMailbox(Mailbox):
|
|
||||||
"""Mailbox backed by a Redis LIST.
|
|
||||||
|
|
||||||
Each actor gets its own Redis key (the ``queue_name``).
|
|
||||||
Messages are serialized as JSON, so payloads must be JSON-compatible.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pool: A ``redis.asyncio.ConnectionPool`` instance.
|
|
||||||
queue_name: Redis key for this actor's inbox (e.g. ``"actor:inbox:worker"``).
|
|
||||||
maxlen: Maximum queue length. 0 = unbounded. When exceeded, ``put_nowait`` returns False.
|
|
||||||
brpop_timeout: Seconds to block on ``get()`` before retrying. Default 1s.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
pool: Any,
|
|
||||||
queue_name: str,
|
|
||||||
*,
|
|
||||||
maxlen: int = 0,
|
|
||||||
brpop_timeout: float = 1.0,
|
|
||||||
) -> None:
|
|
||||||
self._queue_name = queue_name
|
|
||||||
self._maxlen = maxlen
|
|
||||||
self._brpop_timeout = brpop_timeout
|
|
||||||
self._closed = False
|
|
||||||
# Lazy import to avoid hard dependency on redis
|
|
||||||
try:
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
self._redis: aioredis.Redis = aioredis.Redis(connection_pool=pool)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError("RedisMailbox requires 'redis' package. Install with: uv add redis[hiredis]")
|
|
||||||
|
|
||||||
# Lua script for atomic bounded push: check length then push
|
|
||||||
_LUA_BOUNDED_PUSH = """
|
|
||||||
if tonumber(ARGV[2]) > 0 and redis.call('llen', KEYS[1]) >= tonumber(ARGV[2]) then
|
|
||||||
return 0
|
|
||||||
end
|
|
||||||
redis.call('lpush', KEYS[1], ARGV[1])
|
|
||||||
return 1
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def put(self, msg: Any) -> bool:
|
|
||||||
if self._closed:
|
|
||||||
return False
|
|
||||||
data = _serialize(msg)
|
|
||||||
try:
|
|
||||||
if self._maxlen > 0:
|
|
||||||
# Atomic check+push via Lua script to avoid TOCTOU race
|
|
||||||
result = await self._redis.eval(self._LUA_BOUNDED_PUSH, 1, self._queue_name, data, self._maxlen)
|
|
||||||
return bool(result)
|
|
||||||
await self._redis.lpush(self._queue_name, data)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning("RedisMailbox.put failed for %s: %s", self._queue_name, e)
|
|
||||||
return False
|
|
||||||
|
|
||||||
def put_nowait(self, msg: Any) -> bool:
|
|
||||||
"""Redis cannot do synchronous non-blocking enqueue reliably.
|
|
||||||
|
|
||||||
Returns False so the caller uses dead-letter or task.cancel() fallback.
|
|
||||||
Use ``put()`` (async) for reliable delivery.
|
|
||||||
"""
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def put_batch(self, msgs: list[Any]) -> int:
|
|
||||||
"""Push multiple messages in a single LPUSH command (one round-trip).
|
|
||||||
|
|
||||||
Unbounded queues: all messages sent atomically in one LPUSH.
|
|
||||||
Bounded queues: sequential puts to respect maxlen (no batch Lua script needed).
|
|
||||||
"""
|
|
||||||
if self._closed or not msgs:
|
|
||||||
return 0
|
|
||||||
data_list = []
|
|
||||||
for msg in msgs:
|
|
||||||
try:
|
|
||||||
data_list.append(_serialize(msg))
|
|
||||||
except TypeError as e:
|
|
||||||
logger.warning("Skipping non-serializable message in put_batch: %s", e)
|
|
||||||
if not data_list:
|
|
||||||
return 0
|
|
||||||
if self._maxlen > 0:
|
|
||||||
count = 0
|
|
||||||
for data in data_list:
|
|
||||||
# Reuse the Lua script for TOCTOU-safe bounded check (same as put())
|
|
||||||
result = await self._redis.eval(self._LUA_BOUNDED_PUSH, 1, self._queue_name, data, self._maxlen)
|
|
||||||
if result:
|
|
||||||
count += 1
|
|
||||||
else:
|
|
||||||
break # queue full — stop early
|
|
||||||
return count
|
|
||||||
# Unbounded: single LPUSH with all values — one network round-trip
|
|
||||||
await self._redis.lpush(self._queue_name, *data_list)
|
|
||||||
return len(data_list)
|
|
||||||
|
|
||||||
async def get(self) -> Any:
|
|
||||||
"""Blocking dequeue via BRPOP. Retries until a message arrives."""
|
|
||||||
while not self._closed:
|
|
||||||
result = await self._redis.brpop(self._queue_name, timeout=self._brpop_timeout)
|
|
||||||
if result is not None:
|
|
||||||
_, data = result
|
|
||||||
return _deserialize(data)
|
|
||||||
raise Empty("mailbox closed")
|
|
||||||
|
|
||||||
def get_nowait(self) -> Any:
|
|
||||||
raise Empty("Redis mailbox does not support synchronous get_nowait")
|
|
||||||
|
|
||||||
def empty(self) -> bool:
|
|
||||||
# Cannot query Redis synchronously. Return True so drain loops
|
|
||||||
# terminate immediately and rely on get_nowait raising Empty.
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def full(self) -> bool:
|
|
||||||
# Cannot query Redis synchronously. Backpressure enforced
|
|
||||||
# atomically inside put() via Lua script.
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
self._closed = True
|
|
||||||
await self._redis.aclose()
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""Middleware pipeline — cross-cutting concerns for actors.
|
|
||||||
|
|
||||||
Inspired by Proto.Actor's sender/receiver middleware model.
|
|
||||||
Middleware intercepts messages before/after the actor processes them.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
class LoggingMiddleware(Middleware):
|
|
||||||
async def on_receive(self, ctx, message, next_fn):
|
|
||||||
logger.info("Received: %s", message)
|
|
||||||
result = await next_fn(ctx, message)
|
|
||||||
logger.info("Replied: %s", result)
|
|
||||||
return result
|
|
||||||
|
|
||||||
system = ActorSystem("app")
|
|
||||||
ref = await system.spawn(MyActor, "a", middlewares=[LoggingMiddleware()])
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class ActorMailboxContext:
|
|
||||||
"""Context passed to middleware on each message."""
|
|
||||||
|
|
||||||
__slots__ = ("actor_ref", "sender", "message_type")
|
|
||||||
|
|
||||||
def __init__(self, actor_ref: Any, sender: Any, message_type: str) -> None:
|
|
||||||
self.actor_ref = actor_ref
|
|
||||||
self.sender = sender
|
|
||||||
self.message_type = message_type # "tell" or "ask"
|
|
||||||
|
|
||||||
|
|
||||||
# The inner handler signature: (ctx, message) -> result
|
|
||||||
NextFn = Callable[[ActorMailboxContext, Any], Awaitable[Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class Middleware:
|
|
||||||
"""Base class for actor middleware.
|
|
||||||
|
|
||||||
Override ``on_receive`` to intercept inbound messages.
|
|
||||||
Must call ``await next_fn(ctx, message)`` to continue the chain.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def on_receive(self, ctx: ActorMailboxContext, message: Any, next_fn: NextFn) -> Any:
|
|
||||||
"""Intercept a message. Call next_fn to continue the chain."""
|
|
||||||
return await next_fn(ctx, message)
|
|
||||||
|
|
||||||
async def on_started(self, actor_ref: Any) -> None:
|
|
||||||
"""Called when the actor starts."""
|
|
||||||
|
|
||||||
async def on_stopped(self, actor_ref: Any) -> None:
|
|
||||||
"""Called when the actor stops."""
|
|
||||||
|
|
||||||
async def on_restart(self, actor_ref: Any, error: Exception) -> None:
|
|
||||||
"""Called when the actor restarts after a crash.
|
|
||||||
|
|
||||||
Override to reset per-actor-instance state (caches, counters, etc.)
|
|
||||||
that should not bleed across restarts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def build_middleware_chain(middlewares: list[Middleware], handler: NextFn) -> NextFn:
|
|
||||||
"""Build a nested middleware chain ending with *handler*.
|
|
||||||
|
|
||||||
Execution order: first middleware in list wraps outermost.
|
|
||||||
``[A, B, C]`` → ``A(B(C(handler)))``
|
|
||||||
"""
|
|
||||||
chain = handler
|
|
||||||
for mw in reversed(middlewares):
|
|
||||||
outer = chain
|
|
||||||
|
|
||||||
async def _wrap(ctx: ActorMailboxContext, msg: Any, _mw: Middleware = mw, _next: NextFn = outer) -> Any:
|
|
||||||
return await _mw.on_receive(ctx, msg, _next)
|
|
||||||
|
|
||||||
chain = _wrap
|
|
||||||
return chain
|
|
||||||
@@ -1,220 +0,0 @@
|
|||||||
"""ActorRef — immutable, serializable reference to an actor."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import uuid
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .system import _ActorCell
|
|
||||||
|
|
||||||
|
|
||||||
class ActorRef:
|
|
||||||
"""Immutable handle for sending messages to an actor.
|
|
||||||
|
|
||||||
Users never construct this directly — it is returned by
|
|
||||||
``ActorSystem.spawn`` or ``ActorContext.spawn``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_cell",)
|
|
||||||
|
|
||||||
def __init__(self, cell: _ActorCell) -> None:
|
|
||||||
self._cell = cell
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return self._cell.name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def path(self) -> str:
|
|
||||||
return self._cell.path
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_alive(self) -> bool:
|
|
||||||
return not self._cell.stopped
|
|
||||||
|
|
||||||
async def tell(self, message: Any, *, sender: ActorRef | None = None) -> None:
|
|
||||||
"""Fire-and-forget message delivery."""
|
|
||||||
if self._cell.stopped:
|
|
||||||
self._cell.system._dead_letter(self, message, sender)
|
|
||||||
return
|
|
||||||
await self._cell.enqueue(_Envelope(message, sender))
|
|
||||||
|
|
||||||
async def ask(self, message: Any, *, timeout: float = 5.0) -> Any:
|
|
||||||
"""Request-response with timeout.
|
|
||||||
|
|
||||||
Uses correlation ID + ReplyRegistry instead of passing a Future
|
|
||||||
through the mailbox. This makes ask work with any Mailbox backend
|
|
||||||
(memory, Redis, RabbitMQ, etc.).
|
|
||||||
|
|
||||||
Raises ``asyncio.TimeoutError`` if the actor doesn't reply in time.
|
|
||||||
Raises the actor's exception if ``on_receive`` fails.
|
|
||||||
"""
|
|
||||||
if self._cell.stopped:
|
|
||||||
raise ActorStoppedError(f"Actor {self.path} is stopped")
|
|
||||||
corr_id = uuid.uuid4().hex
|
|
||||||
future = self._cell.system._replies.register(corr_id)
|
|
||||||
try:
|
|
||||||
envelope = _Envelope(message, sender=None, correlation_id=corr_id, reply_to=self._cell.system.system_id)
|
|
||||||
await self._cell.enqueue(envelope)
|
|
||||||
return await asyncio.wait_for(future, timeout=timeout)
|
|
||||||
finally:
|
|
||||||
self._cell.system._replies.discard(corr_id)
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Request graceful shutdown."""
|
|
||||||
self._cell.request_stop()
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
alive = "alive" if self.is_alive else "dead"
|
|
||||||
return f"ActorRef({self.path}, {alive})"
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if isinstance(other, ActorRef):
|
|
||||||
return self._cell is other._cell
|
|
||||||
return NotImplemented
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return id(self._cell)
|
|
||||||
|
|
||||||
|
|
||||||
class ActorStoppedError(Exception):
|
|
||||||
"""Raised when sending to a stopped actor via ask."""
|
|
||||||
|
|
||||||
|
|
||||||
class MailboxFullError(RuntimeError):
    """Signals that a mailbox at capacity rejected a message."""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Internal message wrappers (serializable — no Future objects)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _Envelope:
    """Wrapper around a user message as it travels through a mailbox.

    Holds no ``asyncio.Future``: an ask is identified only by its
    ``correlation_id`` and the caller's system ID in ``reply_to``, which is
    what lets ask() work across MQ-backed mailboxes.
    """

    __slots__ = ("payload", "sender", "correlation_id", "reply_to")

    def __init__(
        self,
        payload: Any,
        sender: ActorRef | None = None,
        correlation_id: str | None = None,
        reply_to: str | None = None,
    ) -> None:
        self.payload, self.sender = payload, sender
        self.correlation_id = correlation_id
        # System ID of the caller (for cross-process reply routing).
        self.reply_to = reply_to
|
|
||||||
|
|
||||||
|
|
||||||
class _Stop:
    """Marker object enqueued on a mailbox to request graceful shutdown."""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# ReplyRegistry — maps correlation_id → Future (lives on ActorSystem)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _ReplyRegistry:
    """Process-local table of correlation ID → pending reply Future.

    ask() registers a Future here and awaits it, so no Future ever has to
    be placed in a mailbox.
    """

    def __init__(self) -> None:
        self._pending: dict[str, asyncio.Future[Any]] = {}

    def register(self, corr_id: str) -> asyncio.Future[Any]:
        """Create a Future on the running loop and file it under *corr_id*."""
        loop = asyncio.get_running_loop()
        waiter: asyncio.Future[Any] = loop.create_future()
        self._pending[corr_id] = waiter
        return waiter

    def resolve(self, corr_id: str, result: Any) -> None:
        """Finish the ask filed under *corr_id* with *result*, if still pending."""
        waiter = self._pending.pop(corr_id, None)
        if waiter is None or waiter.done():
            return
        waiter.set_result(result)

    def reject(self, corr_id: str, error: Exception) -> None:
        """Finish the ask filed under *corr_id* with *error*, if still pending."""
        waiter = self._pending.pop(corr_id, None)
        if waiter is None or waiter.done():
            return
        waiter.set_exception(error)

    def discard(self, corr_id: str) -> None:
        """Forget *corr_id* without completing its Future (e.g. on timeout)."""
        self._pending.pop(corr_id, None)

    def reject_all(self, error: Exception) -> None:
        """Fail every still-pending ask with *error* and empty the table."""
        for waiter in self._pending.values():
            if waiter.done():
                continue
            waiter.set_exception(error)
        self._pending.clear()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# ReplyChannel — abstraction for routing replies (local or cross-process)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _ReplyMessage:
    """Reply payload sent through ReplyChannel.

    Carries the original exception object for local delivery (preserves type).
    For cross-process serialization, use ``to_dict``/``from_dict``.

    Attributes:
        correlation_id: ID of the ask this reply answers.
        result: Successful return value (``None`` on failure).
        error: String form of the failure, or ``None`` on success.
        exception: Original exception (local only, not serializable).
    """

    __slots__ = ("correlation_id", "result", "error", "exception")

    def __init__(self, correlation_id: str, result: Any = None, error: str | None = None, exception: Exception | None = None) -> None:
        self.correlation_id = correlation_id
        self.result = result
        self.error = error
        self.exception = exception  # Original exception (local only, not serializable)

    def to_dict(self) -> dict[str, Any]:
        """Serialize for cross-process transport (exception becomes string).

        If only ``exception`` was supplied, its text is used as ``error`` so
        the failure is not lost in transit; without this the receiver would
        see ``error=None`` and treat the reply as a successful ``None``.
        """
        error = self.error
        if error is None and self.exception is not None:
            # str() of an argument-less exception is "", so fall back to the
            # class name to keep the error field informative.
            error = str(self.exception) or type(self.exception).__name__
        return {"correlation_id": self.correlation_id, "result": self.result, "error": error}

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> _ReplyMessage:
        """Rebuild a reply from ``to_dict`` output; ``exception`` is always ``None``."""
        return cls(d["correlation_id"], d.get("result"), d.get("error"))
|
|
||||||
|
|
||||||
|
|
||||||
class ReplyChannel:
    """Carries an actor's reply back to the asking side's reply registry.

    This base class handles the same-process case by completing the
    pending entry in the registry it is given. Subclasses override
    ``send_reply`` (and the listener hooks) for cross-process routing,
    e.g. via Redis pub/sub.
    """

    async def send_reply(self, reply_to: str, reply: _ReplyMessage, local_registry: _ReplyRegistry) -> None:
        """Deliver *reply* to the system identified by *reply_to*.

        The default assumes *reply_to* is the local system and completes
        the entry in *local_registry* directly; *reply_to* itself is unused.
        An attached exception object wins over the string ``error``, which
        is wrapped in ``RuntimeError``; otherwise the result is delivered.
        """
        corr_id = reply.correlation_id
        failure = reply.exception
        if failure is None and reply.error is not None:
            # Cross-process: exception was serialized to string
            failure = RuntimeError(reply.error)

        if failure is None:
            local_registry.resolve(corr_id, reply.result)
        else:
            local_registry.reject(corr_id, failure)

    async def start_listener(self, system_id: str, registry: _ReplyRegistry) -> None:
        """Begin receiving inbound replies; nothing to do for local delivery."""

    async def stop_listener(self) -> None:
        """Stop receiving inbound replies; nothing to do for local delivery."""
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
"""Retry + idempotency helpers for Actor ask/tell patterns.
|
|
||||||
|
|
||||||
This module provides:
|
|
||||||
- Message envelope carrying retry/idempotency metadata
|
|
||||||
- In-memory idempotency store (process-local)
|
|
||||||
- ask_with_retry helper (bounded retries + exponential backoff + jitter)
|
|
||||||
|
|
||||||
Design notes:
|
|
||||||
- Keep transport-agnostic; works with current in-memory mailbox.
|
|
||||||
- Business handlers must opt in by using ``IdempotentActorMixin`` and
|
|
||||||
wrapping logic with ``handle_idempotent``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
|
||||||
class RetryEnvelope:
|
|
||||||
"""Metadata wrapper for idempotent/retriable messages."""
|
|
||||||
|
|
||||||
payload: Any
|
|
||||||
message_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
|
||||||
idempotency_key: str | None = None
|
|
||||||
attempt: int = 1
|
|
||||||
max_attempts: int = 1
|
|
||||||
created_at_ms: int = field(default_factory=lambda: int(time.time() * 1000))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def wrap(
|
|
||||||
cls,
|
|
||||||
payload: Any,
|
|
||||||
*,
|
|
||||||
idempotency_key: str | None = None,
|
|
||||||
attempt: int = 1,
|
|
||||||
max_attempts: int = 1,
|
|
||||||
) -> "RetryEnvelope":
|
|
||||||
return cls(
|
|
||||||
payload=payload,
|
|
||||||
idempotency_key=idempotency_key,
|
|
||||||
attempt=attempt,
|
|
||||||
max_attempts=max_attempts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IdempotencyStore:
    """In-process map from idempotency key to the result computed for it.

    Entries are never evicted by this class.
    """

    def __init__(self) -> None:
        self._results: dict[str, Any] = {}

    def has(self, key: str) -> bool:
        """Tell whether a result was recorded for *key*."""
        return key in self._results

    def get(self, key: str) -> Any:
        """Return the recorded result; raises ``KeyError`` if *key* is unknown."""
        return self._results[key]

    def set(self, key: str, value: Any) -> None:
        """Record (or overwrite) the result for *key*."""
        self._results[key] = value
|
|
||||||
|
|
||||||
|
|
||||||
class IdempotentActorMixin:
    """Mixin that lets an actor answer repeated requests from a result cache.

    Usage in actor::

        class MyActor(IdempotentActorMixin, Actor):
            async def on_receive(self, message):
                return await self.handle_idempotent(message, self._handle)

            async def _handle(self, payload):
                ...
    """

    def _idempotency_store(self) -> IdempotencyStore:
        """Return this instance's store, creating it on first use."""
        existing = getattr(self, "_idem_store", None)
        if existing is not None:
            return existing
        created = IdempotencyStore()
        self._idem_store = created
        return created

    async def handle_idempotent(self, message: Any, handler):
        """Run *handler* at most once per idempotency key.

        Messages that are not ``RetryEnvelope`` are passed to *handler*
        unchanged. Envelopes without a key are unwrapped and handled every
        time. Envelopes with a key return the stored result when one
        exists; otherwise the handler runs on the payload and its result is
        stored. A handler that raises stores nothing.
        """
        if not isinstance(message, RetryEnvelope):
            return await handler(message)

        dedup_key = message.idempotency_key
        if not dedup_key:
            return await handler(message.payload)

        cache = self._idempotency_store()
        if cache.has(dedup_key):
            return cache.get(dedup_key)

        outcome = await handler(message.payload)
        cache.set(dedup_key, outcome)
        return outcome
|
|
||||||
|
|
||||||
|
|
||||||
async def ask_with_retry(
    ref,
    payload: Any,
    *,
    timeout: float = 5.0,
    max_attempts: int = 3,
    base_backoff_s: float = 0.1,
    max_backoff_s: float = 5.0,
    jitter_ratio: float = 0.3,
    retry_exceptions: tuple[type[BaseException], ...] = (asyncio.TimeoutError,),
    idempotency_key: str | None = None,
) -> Any:
    """Ask *ref* up to *max_attempts* times, backing off between failures.

    Every attempt sends a fresh ``RetryEnvelope`` carrying the same
    idempotency key (generated when none is given) and the attempt number.
    Only exceptions listed in *retry_exceptions* trigger a retry; the last
    such exception is re-raised once attempts are exhausted. The delay
    before attempt ``n + 1`` is ``min(max_backoff_s, base_backoff_s * 2**(n-1))``
    plus a random jitter of up to *jitter_ratio* times that delay.

    Raises:
        ValueError: if *max_attempts* is less than 1.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be >= 1")

    key = idempotency_key or uuid.uuid4().hex
    attempt = 0
    while True:
        attempt += 1
        envelope = RetryEnvelope.wrap(
            payload,
            idempotency_key=key,
            attempt=attempt,
            max_attempts=max_attempts,
        )
        try:
            return await ref.ask(envelope, timeout=timeout)
        except retry_exceptions:
            if attempt >= max_attempts:
                # Out of attempts: surface the final failure unchanged.
                raise

        delay = min(max_backoff_s, base_backoff_s * (2 ** (attempt - 1)))
        delay += delay * jitter_ratio * random.random()
        await asyncio.sleep(delay)
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
"""Supervision strategies — Erlang/Akka-inspired fault tolerance."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import enum
|
|
||||||
import time
|
|
||||||
from collections import deque
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class Directive(enum.Enum):
    """Action a supervisor takes in response to a failing child."""

    # Ignore the error and keep processing.
    resume = "resume"
    # Discard state and create a fresh instance.
    restart = "restart"
    # Terminate the child permanently.
    stop = "stop"
    # Propagate the failure to the grandparent.
    escalate = "escalate"
|
|
||||||
|
|
||||||
|
|
||||||
class SupervisorStrategy:
    """Common behaviour shared by all supervision strategies.

    Args:
        max_restarts: Maximum restarts allowed within *within_seconds*.
            Exceeding this limit stops the child permanently.
        within_seconds: Time window for restart counting.
        decider: Maps exception → Directive. Default: always restart.
    """

    def __init__(
        self,
        *,
        max_restarts: int = 3,
        within_seconds: float = 60.0,
        decider: Callable[[Exception], Directive] | None = None,
    ) -> None:
        def _always_restart(_error: Exception) -> Directive:
            return Directive.restart

        self.max_restarts = max_restarts
        self.within_seconds = within_seconds
        self.decider = decider if decider else _always_restart
        # Per-child monotonic timestamps of recent restarts, oldest first.
        self._restart_timestamps: dict[str, deque[float]] = {}

    def decide(self, error: Exception) -> Directive:
        """Ask the configured decider what to do about *error*."""
        return self.decider(error)

    def record_restart(self, child_name: str) -> bool:
        """Log a restart of *child_name*; return True while within the limit."""
        now = time.monotonic()
        history = self._restart_timestamps.setdefault(child_name, deque())
        # Drop restarts that fell out of the sliding window before counting.
        window_start = now - self.within_seconds
        while history and history[0] < window_start:
            history.popleft()
        history.append(now)
        return len(history) <= self.max_restarts

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return the names of the children the directive applies to.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class OneForOneStrategy(SupervisorStrategy):
    """Strategy whose directive touches only the child that failed."""

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return a one-element list holding *failed_child*."""
        affected = [failed_child]
        return affected
|
|
||||||
|
|
||||||
|
|
||||||
class AllForOneStrategy(SupervisorStrategy):
    """Strategy whose directive touches every sibling of the failed child."""

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return a fresh list with every name in *all_children*."""
        return [*all_children]
|
|
||||||
@@ -1,416 +0,0 @@
|
|||||||
"""ActorSystem — top-level actor container and lifecycle manager."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
from collections import deque
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from .actor import Actor, ActorContext
|
|
||||||
from .mailbox import Empty, Mailbox, MemoryMailbox
|
|
||||||
from .middleware import ActorMailboxContext, Middleware, NextFn, build_middleware_chain
|
|
||||||
from .ref import ActorRef, ActorStoppedError, MailboxFullError, ReplyChannel, _Envelope, _ReplyMessage, _ReplyRegistry, _Stop
|
|
||||||
from .supervision import Directive, SupervisorStrategy
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Timeout for middleware lifecycle hooks (on_started/on_stopped)
|
|
||||||
_MIDDLEWARE_HOOK_TIMEOUT = 10.0
|
|
||||||
|
|
||||||
# Maximum dead letters kept in memory
|
|
||||||
_MAX_DEAD_LETTERS = 10000
|
|
||||||
|
|
||||||
# Maximum consecutive message-handling failures before a root actor is stopped
|
|
||||||
_MAX_CONSECUTIVE_FAILURES = 10
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class DeadLetter:
    """Record of one undeliverable message kept by the actor system."""

    # Ref of the actor the message was addressed to.
    recipient: ActorRef
    # The undelivered payload.
    message: Any
    # Ref of the sending actor, or None when no sender was given.
    sender: ActorRef | None
|
|
||||||
|
|
||||||
|
|
||||||
class ActorSystem:
    """Top-level actor container.

    Manages root actors and provides the dead letter sink.
    """

    def __init__(
        self,
        name: str = "system",
        *,
        max_dead_letters: int = _MAX_DEAD_LETTERS,
        executor_workers: int | None = 4,
        reply_channel: ReplyChannel | None = None,
    ) -> None:
        """Create an empty system.

        Args:
            name: System name; it is the first segment of every actor path.
            max_dead_letters: Capacity of the in-memory dead-letter buffer;
                once full, the oldest entries are dropped.
            executor_workers: Size of the shared thread pool. A falsy value
                (``None`` or ``0``) disables the pool.
            reply_channel: Transport for ask replies; defaults to the
                in-process ``ReplyChannel``.
        """
        import uuid as _uuid
        from concurrent.futures import ThreadPoolExecutor

        self.name = name
        # Short random suffix distinguishes systems that share a name.
        self.system_id = f"{name}-{_uuid.uuid4().hex[:8]}"
        self._root_cells: dict[str, _ActorCell] = {}
        self._dead_letters: deque[DeadLetter] = deque(maxlen=max_dead_letters)
        self._on_dead_letter: list[Any] = []
        self._shutting_down = False
        self._replies = _ReplyRegistry()
        self._reply_channel = reply_channel or ReplyChannel()
        # Shared thread pool for actors to run blocking I/O
        self._executor = ThreadPoolExecutor(max_workers=executor_workers, thread_name_prefix=f"actor-{name}") if executor_workers else None

    async def spawn(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        mailbox: Mailbox | None = None,
        middlewares: list[Middleware] | None = None,
    ) -> ActorRef:
        """Spawn a root-level actor.

        Args:
            actor_cls: Actor class to instantiate (called with no arguments).
            name: Unique name among root actors.
            mailbox_size: Capacity of the default ``MemoryMailbox``.
            mailbox: Custom mailbox instance. If None, uses MemoryMailbox(mailbox_size).
            middlewares: Middleware applied around the actor's receive handler.

        Returns:
            The ref of the started actor.

        Raises:
            ValueError: if a root actor named *name* already exists.
        """
        if name in self._root_cells:
            raise ValueError(f"Root actor '{name}' already exists")
        cell = _ActorCell(
            actor_cls=actor_cls,
            name=name,
            parent=None,
            system=self,
            mailbox=mailbox or MemoryMailbox(mailbox_size),
            middlewares=middlewares or [],
        )
        self._root_cells[name] = cell
        try:
            await cell.start()
        except Exception:
            # Startup failed: do not leave a half-started cell registered.
            del self._root_cells[name]
            raise
        return cell.ref

    async def shutdown(self, *, timeout: float = 10.0) -> None:
        """Gracefully stop all actors.

        Root actors get up to *timeout* seconds to finish; stragglers are
        cancelled. Pending asks are then rejected with ``ActorStoppedError``,
        the reply listener is stopped and the thread pool is shut down
        without waiting.
        """
        self._shutting_down = True
        tasks = []
        for cell in list(self._root_cells.values()):
            cell.request_stop()
            if cell.task is not None:
                tasks.append(cell.task)
        if tasks:
            _, pending = await asyncio.wait(tasks, timeout=timeout)
            # Cancel tasks that didn't finish within the timeout to prevent zombie tasks
            for t in pending:
                t.cancel()
            if pending:
                await asyncio.wait(pending, timeout=2.0)
        self._root_cells.clear()
        self._replies.reject_all(ActorStoppedError("ActorSystem shutting down"))
        await self._reply_channel.stop_listener()
        if self._executor is not None:
            self._executor.shutdown(wait=False)
        logger.info("ActorSystem '%s' shut down (%d dead letters)", self.name, len(self._dead_letters))

    def _dead_letter(self, recipient: ActorRef, message: Any, sender: ActorRef | None) -> None:
        """Record an undeliverable message and notify registered listeners."""
        dl = DeadLetter(recipient=recipient, message=message, sender=sender)
        self._dead_letters.append(dl)
        for cb in self._on_dead_letter:
            try:
                cb(dl)
            except Exception:
                # Listeners are best-effort: one failing must not stop the
                # others, but the failure should still be visible in logs.
                logger.warning("Dead letter listener %r failed", cb, exc_info=True)
        logger.debug("Dead letter: %s → %s", type(message).__name__, recipient.path)

    def on_dead_letter(self, callback: Any) -> None:
        """Register a dead letter listener.

        *callback* is called synchronously with each new ``DeadLetter``.
        """
        self._on_dead_letter.append(callback)

    @property
    def dead_letters(self) -> list[DeadLetter]:
        """Snapshot copy of the buffered dead letters, oldest first."""
        return list(self._dead_letters)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _ActorCell — internal runtime wrapper
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _ActorCell:
    """Runtime container for a single actor instance.

    Manages the mailbox, processing loop, children, and supervision.
    Not part of the public API.
    """

    def __init__(
        self,
        actor_cls: type[Actor],
        name: str,
        parent: _ActorCell | None,
        system: ActorSystem,
        mailbox: Mailbox,
        middlewares: list[Middleware] | None = None,
    ) -> None:
        """Set up the cell; the actor instance and task are created by ``start``."""
        self.actor_cls = actor_cls
        self.name = name
        self.parent = parent
        self.system = system
        self.children: dict[str, _ActorCell] = {}
        self.mailbox = mailbox
        self.ref = ActorRef(self)
        self.actor: Actor | None = None
        self.task: asyncio.Task[None] | None = None
        self.stopped = False
        # Resolved lazily from the actor on the first child failure.
        self._supervisor_strategy: SupervisorStrategy | None = None
        self._middlewares = middlewares or []
        self._receive_chain: NextFn | None = None
        # Cache path (immutable after init — parent never changes)
        parts: list[str] = []
        cell: _ActorCell | None = self
        # Walk up to the root collecting names, then prepend the system name.
        while cell is not None:
            parts.append(cell.name)
            cell = cell.parent
        parts.append(system.name)
        self.path = "/" + "/".join(reversed(parts))

    async def start(self) -> None:
        """Instantiate the actor, run start hooks and launch the processing task.

        A middleware ``on_started`` that times out is logged and skipped;
        any other exception from the hooks propagates to the caller.
        """
        self.actor = self.actor_cls()
        self.actor.context = ActorContext(self)

        # Innermost handler of the chain. It reads self.actor at call time,
        # so it dispatches to the replacement instance after a restart.
        async def _inner_handler(_ctx: ActorMailboxContext, message: Any) -> Any:
            return await self.actor.on_receive(message)  # type: ignore[union-attr]

        if self._middlewares:
            self._receive_chain = build_middleware_chain(self._middlewares, _inner_handler)
        else:
            self._receive_chain = _inner_handler
        # Notify middleware of start (with timeout to prevent blocking)
        for mw in self._middlewares:
            try:
                await asyncio.wait_for(mw.on_started(self.ref), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
            except asyncio.TimeoutError:
                logger.warning("Middleware %s.on_started timed out for %s", type(mw).__name__, self.path)
        await self.actor.on_started()
        self.task = asyncio.create_task(self._run(), name=f"actor:{self.path}")

    async def enqueue(self, msg: _Envelope | _Stop) -> None:
        """Put *msg* on the mailbox, handling rejection.

        If both ``put_nowait`` and ``put`` report failure, an ask envelope
        has its pending reply rejected with ``MailboxFullError``, a tell
        envelope goes to dead letters, and a ``_Stop`` is dropped.
        """
        # Try non-blocking first (fast path for MemoryMailbox)
        if self.mailbox.put_nowait(msg):
            return
        # Fallback to async put (required for Redis and other async backends)
        if not await self.mailbox.put(msg):
            if isinstance(msg, _Envelope) and msg.correlation_id is not None:
                self.system._replies.reject(msg.correlation_id, MailboxFullError(f"Mailbox full: {self.path}"))
            elif isinstance(msg, _Envelope):
                self.system._dead_letter(self.ref, msg.payload, msg.sender)

    def request_stop(self) -> None:
        """Request graceful shutdown.

        Tries put_nowait first. If that fails (full or unsupported backend),
        cancels the task directly so _run exits via CancelledError → finally → _shutdown.
        """
        if not self.stopped:
            if not self.mailbox.put_nowait(_Stop()):
                # Redis/async backends can't put_nowait — cancel the task
                if self.task is not None and not self.task.done():
                    self.task.cancel()
                else:
                    # NOTE(review): nesting of this else was reconstructed from
                    # whitespace-stripped source (no running task → just mark
                    # the cell stopped) — confirm against the original file.
                    self.stopped = True

    async def spawn_child(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        mailbox: Mailbox | None = None,
        middlewares: list[Middleware] | None = None,
    ) -> ActorRef:
        """Create and start a child cell under this one and return its ref.

        Raises ``ValueError`` if a child named *name* already exists. If the
        child's ``start`` raises, the child is unregistered and the
        exception propagates.
        """
        if name in self.children:
            raise ValueError(f"Child '{name}' already exists under {self.path}")
        child = _ActorCell(
            actor_cls=actor_cls,
            name=name,
            parent=self,
            system=self.system,
            mailbox=mailbox or MemoryMailbox(mailbox_size),
            middlewares=middlewares or [],
        )
        self.children[name] = child
        try:
            await child.start()
        except Exception:
            del self.children[name]
            raise
        return child.ref

    # -- Processing loop -------------------------------------------------------

    async def _run(self) -> None:
        """Process mailbox messages until stopped, then run ``_shutdown``.

        The loop ends on a ``_Stop`` sentinel, on cancellation, when
        ``self.stopped`` becomes true, or — for root actors only — after
        ``_MAX_CONSECUTIVE_FAILURES`` handler failures in a row.
        """
        consecutive_failures = 0
        try:
            while not self.stopped:
                try:
                    msg = await self.mailbox.get()
                except asyncio.CancelledError:
                    break

                if isinstance(msg, _Stop):
                    break

                try:
                    # Anything that is not an envelope is ignored.
                    if not isinstance(msg, _Envelope):
                        continue
                    msg_type = "ask" if msg.correlation_id else "tell"
                    ctx = ActorMailboxContext(self.ref, msg.sender, msg_type)
                    result = await self._receive_chain(ctx, msg.payload)  # type: ignore[misc]
                    if msg.correlation_id is not None:
                        reply = _ReplyMessage(msg.correlation_id, result=result)
                        # Fall back to the local system when reply_to is unset.
                        await self.system._reply_channel.send_reply(msg.reply_to or self.system.system_id, reply, self.system._replies)
                    consecutive_failures = 0
                except Exception as exc:
                    # For asks, report the failure to the caller first.
                    if isinstance(msg, _Envelope) and msg.correlation_id is not None:
                        reply = _ReplyMessage(msg.correlation_id, error=str(exc), exception=exc)
                        await self.system._reply_channel.send_reply(msg.reply_to or self.system.system_id, reply, self.system._replies)
                    if self.parent is not None:
                        # Child actors defer to the parent's supervision strategy.
                        await self.parent._handle_child_failure(self, exc)
                    else:
                        # Root actors have no supervisor: count failures and
                        # stop once the consecutive limit is reached.
                        consecutive_failures += 1
                        logger.error("Uncaught error in root actor %s (%d/%d): %s", self.path, consecutive_failures, _MAX_CONSECUTIVE_FAILURES, exc)
                        if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES:
                            logger.error("Root actor %s hit consecutive failure limit — stopping", self.path)
                            break
        except asyncio.CancelledError:
            pass  # Fall through to _shutdown
        finally:
            await self._shutdown()

    async def _shutdown(self) -> None:
        """Stop children, drain the mailbox, run stop hooks and close the mailbox.

        Order: mark stopped → stop children (10 s budget, then cancel) →
        drain remaining messages (asks rejected with ``ActorStoppedError``,
        tells sent to dead letters) → middleware ``on_stopped`` → actor
        ``on_stopped`` → detach from parent → close mailbox. Hook and close
        errors are logged, not raised.
        """
        self.stopped = True
        # Parallel child shutdown prevents cascading timeouts.
        child_tasks = []
        for child in list(self.children.values()):
            child.request_stop()
            if child.task is not None:
                child_tasks.append(child.task)
        if child_tasks:
            _, pending = await asyncio.wait(child_tasks, timeout=10.0)
            for t in pending:
                t.cancel()
                # Mark leaked children as stopped
                # NOTE(review): nesting of this inner loop under `for t in
                # pending` was reconstructed from whitespace-stripped source —
                # confirm against the original file.
                for child in self.children.values():
                    if child.task is t:
                        child.stopped = True
        # Drain mailbox → dead letters (use try/except to handle all backends)
        while True:
            try:
                msg = self.mailbox.get_nowait()
            except Empty:
                break
            if isinstance(msg, _Envelope):
                if msg.correlation_id is not None:
                    self.system._replies.reject(msg.correlation_id, ActorStoppedError(f"Actor {self.path} stopped"))
                else:
                    self.system._dead_letter(self.ref, msg.payload, msg.sender)
        # Lifecycle hook
        for mw in self._middlewares:
            try:
                await asyncio.wait_for(mw.on_stopped(self.ref), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
            except asyncio.TimeoutError:
                logger.warning("Middleware %s.on_stopped timed out for %s", type(mw).__name__, self.path)
            except Exception:
                logger.exception("Error in middleware on_stopped for %s", self.path)
        if self.actor is not None:
            try:
                await self.actor.on_stopped()
            except Exception:
                logger.exception("Error in on_stopped for %s", self.path)
        # Remove from parent
        if self.parent is not None:
            self.parent.children.pop(self.name, None)
        # Close mailbox to release backend resources (e.g. Redis connections)
        try:
            await self.mailbox.close()
        except Exception:
            logger.exception("Error closing mailbox for %s", self.path)

    # -- Supervision -----------------------------------------------------------

    def _get_supervisor_strategy(self) -> SupervisorStrategy:
        """Return the actor's supervision strategy, fetched once and cached."""
        if self._supervisor_strategy is None:
            self._supervisor_strategy = self.actor.supervisor_strategy()  # type: ignore[union-attr]
        return self._supervisor_strategy

    async def _handle_child_failure(self, child: _ActorCell, error: Exception) -> None:
        """Apply this actor's supervision strategy to a failed *child*.

        The strategy picks a directive for *error* and the set of affected
        children; the directive is then carried out:

        - resume: log only.
        - stop: request stop on every affected child.
        - escalate: stop the failing child, then hand the error to this
          cell's own parent (or log it when this cell is a root).
        - restart: restart each affected child, or stop it when the
          strategy reports the restart limit was exceeded.
        """
        strategy = self._get_supervisor_strategy()
        directive = strategy.decide(error)

        affected = strategy.apply_to_children(child.name, list(self.children.keys()))

        if directive == Directive.resume:
            logger.info("Supervisor %s: resume %s after %s", self.path, child.path, type(error).__name__)
            return

        if directive == Directive.stop:
            for name in affected:
                c = self.children.get(name)
                if c is not None:
                    c.request_stop()
            logger.info("Supervisor %s: stop %s after %s", self.path, [self.children[n].path for n in affected if n in self.children], type(error).__name__)
            return

        if directive == Directive.escalate:
            # Stop the failing child, then propagate failure up the supervision chain.
            # We cannot use `raise error` here — that would crash the child's _run
            # loop instead of notifying the grandparent's supervisor.
            child.request_stop()
            if self.parent is not None:
                logger.info("Supervisor %s: escalate %s to grandparent %s", self.path, type(error).__name__, self.parent.path)
                await self.parent._handle_child_failure(self, error)
            else:
                logger.error("Uncaught escalation at root actor %s: %s", self.path, error)
            return

        if directive == Directive.restart:
            for name in affected:
                c = self.children.get(name)
                if c is None:
                    continue
                # record_restart both counts this restart and checks the limit.
                if not strategy.record_restart(name):
                    logger.warning("Supervisor %s: child %s exceeded restart limit — stopping", self.path, c.path)
                    c.request_stop()
                    continue
                await self._restart_child(c, error)

    async def _restart_child(self, child: _ActorCell, error: Exception) -> None:
        """Replace *child*'s actor instance with a fresh one, in place.

        The cell, its mailbox and its processing task are kept. Errors in
        the old actor's ``on_stopped`` and in middleware ``on_restart`` are
        logged and ignored; if the new actor's ``on_restart``/``on_started``
        raises, the child is asked to stop.
        """
        logger.info("Supervisor %s: restarting %s after %s", self.path, child.path, type(error).__name__)
        # Stop the old actor (but keep the cell and mailbox)
        old_actor = child.actor
        if old_actor is not None:
            try:
                await old_actor.on_stopped()
            except Exception:
                logger.exception("Error in on_stopped during restart of %s", child.path)

        # Notify middleware of restart (reset per-instance state)
        for mw in child._middlewares:
            try:
                await asyncio.wait_for(mw.on_restart(child.ref, error), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
            except asyncio.TimeoutError:
                logger.warning("Middleware %s.on_restart timed out for %s", type(mw).__name__, child.path)
            except Exception:
                logger.exception("Error in middleware on_restart for %s", child.path)
        # Create fresh instance
        new_actor = child.actor_cls()
        new_actor.context = ActorContext(child)
        child.actor = new_actor
        try:
            await new_actor.on_restart(error)
            await new_actor.on_started()
        except Exception:
            logger.exception("Error during restart initialization of %s", child.path)
            child.request_stop()
|
|
||||||
@@ -477,6 +477,28 @@ def _build_acp_section() -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_custom_mounts_section() -> str:
    """Build a prompt section for explicitly configured sandbox mounts.

    Returns:
        A markdown fragment with one bullet per configured mount, or an empty
        string when no mounts are configured or the app config cannot be
        loaded (the failure is logged).
    """
    try:
        from deerflow.config import get_app_config

        configured = get_app_config().sandbox.mounts or []
    except Exception:
        logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
        return ""

    if not configured:
        return ""

    def _describe(mount) -> str:
        # One bullet per mount: container path plus its access mode.
        access = "read-only" if mount.read_only else "read-write"
        return f"- Custom mount: `{mount.container_path}` - Host directory mapped into the sandbox ({access})"

    mounts_list = "\n".join(_describe(mount) for mount in configured)
    return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
|
||||||
|
|
||||||
|
|
||||||
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
|
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
|
||||||
# Get memory context
|
# Get memory context
|
||||||
memory_context = _get_memory_context(agent_name)
|
memory_context = _get_memory_context(agent_name)
|
||||||
@@ -511,6 +533,8 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
|
|||||||
|
|
||||||
# Build ACP agent section only if ACP agents are configured
|
# Build ACP agent section only if ACP agents are configured
|
||||||
acp_section = _build_acp_section()
|
acp_section = _build_acp_section()
|
||||||
|
custom_mounts_section = _build_custom_mounts_section()
|
||||||
|
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
|
||||||
|
|
||||||
# Format the prompt with dynamic skills and memory
|
# Format the prompt with dynamic skills and memory
|
||||||
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
@@ -522,7 +546,7 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
|
|||||||
subagent_section=subagent_section,
|
subagent_section=subagent_section,
|
||||||
subagent_reminder=subagent_reminder,
|
subagent_reminder=subagent_reminder,
|
||||||
subagent_thinking=subagent_thinking,
|
subagent_thinking=subagent_thinking,
|
||||||
acp_section=acp_section,
|
acp_section=acp_and_mounts_section,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
return prompt + f"\n<current_date>{datetime.now().strftime('%Y-%m-%d, %A')}</current_date>"
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
import shlex
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
from agent_sandbox import Sandbox as AioSandboxClient
|
from agent_sandbox import Sandbox as AioSandboxClient
|
||||||
|
|
||||||
@@ -7,11 +10,15 @@ from deerflow.sandbox.sandbox import Sandbox
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'"
|
||||||
|
|
||||||
|
|
||||||
class AioSandbox(Sandbox):
|
class AioSandbox(Sandbox):
|
||||||
"""Sandbox implementation using the agent-infra/sandbox Docker container.
|
"""Sandbox implementation using the agent-infra/sandbox Docker container.
|
||||||
|
|
||||||
This sandbox connects to a running AIO sandbox container via HTTP API.
|
This sandbox connects to a running AIO sandbox container via HTTP API.
|
||||||
|
A threading lock serializes shell commands to prevent concurrent requests
|
||||||
|
from corrupting the container's single persistent session (see #1433).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id: str, base_url: str, home_dir: str | None = None):
|
def __init__(self, id: str, base_url: str, home_dir: str | None = None):
|
||||||
@@ -26,6 +33,7 @@ class AioSandbox(Sandbox):
|
|||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
self._client = AioSandboxClient(base_url=base_url, timeout=600)
|
||||||
self._home_dir = home_dir
|
self._home_dir = home_dir
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_url(self) -> str:
|
def base_url(self) -> str:
|
||||||
@@ -42,19 +50,34 @@ class AioSandbox(Sandbox):
|
|||||||
def execute_command(self, command: str) -> str:
|
def execute_command(self, command: str) -> str:
|
||||||
"""Execute a shell command in the sandbox.
|
"""Execute a shell command in the sandbox.
|
||||||
|
|
||||||
|
Uses a lock to serialize concurrent requests. The AIO sandbox
|
||||||
|
container maintains a single persistent shell session that
|
||||||
|
corrupts when hit with concurrent exec_command calls (returns
|
||||||
|
``ErrorObservation`` instead of real output). If corruption is
|
||||||
|
detected despite the lock (e.g. multiple processes sharing a
|
||||||
|
sandbox), the command is retried on a fresh session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
command: The command to execute.
|
command: The command to execute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The output of the command.
|
The output of the command.
|
||||||
"""
|
"""
|
||||||
try:
|
with self._lock:
|
||||||
result = self._client.shell.exec_command(command=command)
|
try:
|
||||||
output = result.data.output if result.data else ""
|
result = self._client.shell.exec_command(command=command)
|
||||||
return output if output else "(no output)"
|
output = result.data.output if result.data else ""
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to execute command in sandbox: {e}")
|
if output and _ERROR_OBSERVATION_SIGNATURE in output:
|
||||||
return f"Error: {e}"
|
logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
|
||||||
|
fresh_id = str(uuid.uuid4())
|
||||||
|
result = self._client.shell.exec_command(command=command, id=fresh_id)
|
||||||
|
output = result.data.output if result.data else ""
|
||||||
|
|
||||||
|
return output if output else "(no output)"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to execute command in sandbox: {e}")
|
||||||
|
return f"Error: {e}"
|
||||||
|
|
||||||
def read_file(self, path: str) -> str:
|
def read_file(self, path: str) -> str:
|
||||||
"""Read the content of a file in the sandbox.
|
"""Read the content of a file in the sandbox.
|
||||||
@@ -82,17 +105,16 @@ class AioSandbox(Sandbox):
|
|||||||
Returns:
|
Returns:
|
||||||
The contents of the directory.
|
The contents of the directory.
|
||||||
"""
|
"""
|
||||||
try:
|
with self._lock:
|
||||||
# Use shell command to list directory with depth limit
|
try:
|
||||||
# The -L flag limits the depth for the tree command
|
result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
|
||||||
result = self._client.shell.exec_command(command=f"find {path} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
|
output = result.data.output if result.data else ""
|
||||||
output = result.data.output if result.data else ""
|
if output:
|
||||||
if output:
|
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
||||||
return [line.strip() for line in output.strip().split("\n") if line.strip()]
|
return []
|
||||||
return []
|
except Exception as e:
|
||||||
except Exception as e:
|
logger.error(f"Failed to list directory in sandbox: {e}")
|
||||||
logger.error(f"Failed to list directory in sandbox: {e}")
|
return []
|
||||||
return []
|
|
||||||
|
|
||||||
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
def write_file(self, path: str, content: str, append: bool = False) -> None:
|
||||||
"""Write content to a file in the sandbox.
|
"""Write content to a file in the sandbox.
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ except ImportError: # pragma: no cover - Windows fallback
|
|||||||
import msvcrt
|
import msvcrt
|
||||||
|
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, Paths, get_paths
|
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
|
||||||
from deerflow.sandbox.sandbox import Sandbox
|
from deerflow.sandbox.sandbox import Sandbox
|
||||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||||
|
|
||||||
@@ -214,17 +214,13 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
paths.ensure_thread_dirs(thread_id)
|
paths.ensure_thread_dirs(thread_id)
|
||||||
|
|
||||||
# host_paths resolves to the host-side base dir when DEER_FLOW_HOST_BASE_DIR
|
|
||||||
# is set, otherwise falls back to the container's own base dir (native mode).
|
|
||||||
host_paths = Paths(base_dir=paths.host_base_dir)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
(str(host_paths.sandbox_work_dir(thread_id)), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
|
||||||
(str(host_paths.sandbox_uploads_dir(thread_id)), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
|
||||||
(str(host_paths.sandbox_outputs_dir(thread_id)), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
|
||||||
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
# ACP workspace: read-only inside the sandbox (lead agent reads results;
|
||||||
# the ACP subprocess writes from the host side, not from within the container).
|
# the ACP subprocess writes from the host side, not from within the container).
|
||||||
(str(host_paths.acp_workspace_dir(thread_id)), "/mnt/acp-workspace", True),
|
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -18,6 +18,26 @@ from .sandbox_info import SandboxInfo
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_container_mount(runtime: str, host_path: str, container_path: str, read_only: bool) -> list[str]:
    """Return the CLI arguments for one bind mount on the given runtime.

    Docker's ``-v host:container`` syntax is ambiguous for Windows
    drive-letter paths like ``D:/...`` because ``:`` is both the drive
    separator and the volume separator, so Docker gets the
    ``--mount type=bind,...`` form. Every other runtime value keeps ``-v``.

    Args:
        runtime: Container runtime name; only ``"docker"`` selects ``--mount``.
        host_path: Bind-mount source on the host.
        container_path: Mount target inside the container.
        read_only: Whether the mount is read-only.

    Returns:
        A two-element list: the flag and its mount specification.
    """
    if runtime != "docker":
        suffix = ":ro" if read_only else ""
        return ["-v", f"{host_path}:{container_path}{suffix}"]

    fields = ["type=bind", f"src={host_path}", f"dst={container_path}"]
    if read_only:
        fields.append("readonly")
    return ["--mount", ",".join(fields)]
|
||||||
|
|
||||||
|
|
||||||
class LocalContainerBackend(SandboxBackend):
|
class LocalContainerBackend(SandboxBackend):
|
||||||
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
"""Backend that manages sandbox containers locally using Docker or Apple Container.
|
||||||
|
|
||||||
@@ -246,18 +266,26 @@ class LocalContainerBackend(SandboxBackend):
|
|||||||
|
|
||||||
# Config-level volume mounts
|
# Config-level volume mounts
|
||||||
for mount in self._config_mounts:
|
for mount in self._config_mounts:
|
||||||
mount_spec = f"{mount.host_path}:{mount.container_path}"
|
cmd.extend(
|
||||||
if mount.read_only:
|
_format_container_mount(
|
||||||
mount_spec += ":ro"
|
self._runtime,
|
||||||
cmd.extend(["-v", mount_spec])
|
mount.host_path,
|
||||||
|
mount.container_path,
|
||||||
|
mount.read_only,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Extra mounts (thread-specific, skills, etc.)
|
# Extra mounts (thread-specific, skills, etc.)
|
||||||
if extra_mounts:
|
if extra_mounts:
|
||||||
for host_path, container_path, read_only in extra_mounts:
|
for host_path, container_path, read_only in extra_mounts:
|
||||||
mount_spec = f"{host_path}:{container_path}"
|
cmd.extend(
|
||||||
if read_only:
|
_format_container_mount(
|
||||||
mount_spec += ":ro"
|
self._runtime,
|
||||||
cmd.extend(["-v", mount_spec])
|
host_path,
|
||||||
|
container_path,
|
||||||
|
read_only,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
cmd.append(self._image)
|
cmd.append(self._image)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path, PureWindowsPath
|
||||||
|
|
||||||
# Virtual path prefix seen by agents inside the sandbox
|
# Virtual path prefix seen by agents inside the sandbox
|
||||||
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
||||||
@@ -9,6 +9,41 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
|||||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_thread_id(thread_id: str) -> str:
    """Validate a thread ID before using it in filesystem paths.

    Args:
        thread_id: Candidate thread identifier.

    Returns:
        ``thread_id`` unchanged when it is valid.

    Raises:
        ValueError: If ``thread_id`` contains anything other than
            alphanumeric characters, hyphens, and underscores.
    """
    # fullmatch, not match: with match() a "$"-anchored pattern also accepts
    # a single trailing newline (e.g. "abc\n").
    if not _SAFE_THREAD_ID_RE.fullmatch(thread_id):
        raise ValueError(f"Invalid thread_id {thread_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
    return thread_id
|
||||||
|
|
||||||
|
|
||||||
|
def _join_host_path(base: str, *parts: str) -> str:
    """Join host filesystem path segments while preserving native style.

    Docker Desktop on Windows expects bind mount sources to stay in Windows
    path form (for example ``C:\\repo\\backend\\.deer-flow``). Joining such
    a base with ``Path(base) / ...`` on a POSIX host can produce mixed
    separators, so a Windows-looking base is joined with ``PureWindowsPath``.

    Args:
        base: Base path as a raw string.
        *parts: Segments to append. With no segments, ``base`` is returned
            untouched.

    Returns:
        The joined path as a string.
    """
    if not parts:
        return base

    # Drive-letter prefix, UNC prefix, or any backslash marks a Windows path.
    looks_windows = bool(re.match(r"^[A-Za-z]:[\\/]", base)) or base.startswith("\\\\") or "\\" in base
    path_type = PureWindowsPath if looks_windows else Path
    return str(path_type(base).joinpath(*parts))


def join_host_path(base: str, *parts: str) -> str:
    """Join host filesystem path segments while preserving native style.

    Public entry point that delegates to ``_join_host_path``.
    """
    return _join_host_path(base, *parts)
|
||||||
|
|
||||||
|
|
||||||
class Paths:
|
class Paths:
|
||||||
"""
|
"""
|
||||||
Centralized path configuration for DeerFlow application data.
|
Centralized path configuration for DeerFlow application data.
|
||||||
@@ -54,6 +89,12 @@ class Paths:
|
|||||||
return Path(env)
|
return Path(env)
|
||||||
return self.base_dir
|
return self.base_dir
|
||||||
|
|
||||||
|
def _host_base_dir_str(self) -> str:
|
||||||
|
"""Return the host base dir as a raw string for bind mounts."""
|
||||||
|
if env := os.getenv("DEER_FLOW_HOST_BASE_DIR"):
|
||||||
|
return env
|
||||||
|
return str(self.base_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def base_dir(self) -> Path:
|
def base_dir(self) -> Path:
|
||||||
"""Root directory for all application data."""
|
"""Root directory for all application data."""
|
||||||
@@ -103,9 +144,7 @@ class Paths:
|
|||||||
ValueError: If `thread_id` contains unsafe characters (path separators
|
ValueError: If `thread_id` contains unsafe characters (path separators
|
||||||
or `..`) that could cause directory traversal.
|
or `..`) that could cause directory traversal.
|
||||||
"""
|
"""
|
||||||
if not _SAFE_THREAD_ID_RE.match(thread_id):
|
return self.base_dir / "threads" / _validate_thread_id(thread_id)
|
||||||
raise ValueError(f"Invalid thread_id {thread_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
|
|
||||||
return self.base_dir / "threads" / thread_id
|
|
||||||
|
|
||||||
def sandbox_work_dir(self, thread_id: str) -> Path:
|
def sandbox_work_dir(self, thread_id: str) -> Path:
|
||||||
"""
|
"""
|
||||||
@@ -150,6 +189,30 @@ class Paths:
|
|||||||
"""
|
"""
|
||||||
return self.thread_dir(thread_id) / "user-data"
|
return self.thread_dir(thread_id) / "user-data"
|
||||||
|
|
||||||
|
def host_thread_dir(self, thread_id: str) -> str:
    """Host path for a thread directory, preserving Windows path syntax.

    Joins the host base dir with ``threads`` and the validated thread ID.
    """
    host_root = self._host_base_dir_str()
    return _join_host_path(host_root, "threads", _validate_thread_id(thread_id))
|
||||||
|
|
||||||
|
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
    """Host path for a thread's user-data root (``<thread dir>/user-data``)."""
    thread_root = self.host_thread_dir(thread_id)
    return _join_host_path(thread_root, "user-data")
|
||||||
|
|
||||||
|
def host_sandbox_work_dir(self, thread_id: str) -> str:
    """Host path for the workspace mount source (``user-data/workspace``)."""
    user_data_root = self.host_sandbox_user_data_dir(thread_id)
    return _join_host_path(user_data_root, "workspace")
|
||||||
|
|
||||||
|
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
    """Host path for the uploads mount source (``user-data/uploads``)."""
    user_data_root = self.host_sandbox_user_data_dir(thread_id)
    return _join_host_path(user_data_root, "uploads")
|
||||||
|
|
||||||
|
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
    """Host path for the outputs mount source (``user-data/outputs``)."""
    user_data_root = self.host_sandbox_user_data_dir(thread_id)
    return _join_host_path(user_data_root, "outputs")
|
||||||
|
|
||||||
|
def host_acp_workspace_dir(self, thread_id: str) -> str:
    """Host path for the ACP workspace mount source (``<thread dir>/acp-workspace``)."""
    thread_root = self.host_thread_dir(thread_id)
    return _join_host_path(thread_root, "acp-workspace")
|
||||||
|
|
||||||
def ensure_thread_dirs(self, thread_id: str) -> None:
|
def ensure_thread_dirs(self, thread_id: str) -> None:
|
||||||
"""Create all standard sandbox directories for a thread.
|
"""Create all standard sandbox directories for a thread.
|
||||||
|
|
||||||
|
|||||||
@@ -81,11 +81,9 @@ class RunManager:
|
|||||||
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
|
async def list_by_thread(self, thread_id: str) -> list[RunRecord]:
|
||||||
"""Return all runs for a given thread, newest first."""
|
"""Return all runs for a given thread, newest first."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
return sorted(
|
# Dict insertion order matches creation order, so reversing it gives
|
||||||
(r for r in self._runs.values() if r.thread_id == thread_id),
|
# us deterministic newest-first results even when timestamps tie.
|
||||||
key=lambda r: r.created_at,
|
return [r for r in reversed(self._runs.values()) if r.thread_id == thread_id]
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||||
"""Transition a run to a new status."""
|
"""Transition a run to a new status."""
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ You have access to the sandbox environment:
|
|||||||
- User uploads: `/mnt/user-data/uploads`
|
- User uploads: `/mnt/user-data/uploads`
|
||||||
- User workspace: `/mnt/user-data/workspace`
|
- User workspace: `/mnt/user-data/workspace`
|
||||||
- Output files: `/mnt/user-data/outputs`
|
- Output files: `/mnt/user-data/outputs`
|
||||||
|
- Deployment-configured custom mounts may also be available at other absolute container paths; use them directly when the task references those mounted directories
|
||||||
</working_directory>
|
</working_directory>
|
||||||
""",
|
""",
|
||||||
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
|
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ You have access to the same sandbox environment as the parent agent:
|
|||||||
- User uploads: `/mnt/user-data/uploads`
|
- User uploads: `/mnt/user-data/uploads`
|
||||||
- User workspace: `/mnt/user-data/workspace`
|
- User workspace: `/mnt/user-data/workspace`
|
||||||
- Output files: `/mnt/user-data/outputs`
|
- Output files: `/mnt/user-data/outputs`
|
||||||
|
- Deployment-configured custom mounts may also be available at other absolute container paths; use them directly when the task references those mounted directories
|
||||||
</working_directory>
|
</working_directory>
|
||||||
""",
|
""",
|
||||||
tools=None, # Inherit all tools from parent
|
tools=None, # Inherit all tools from parent
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from langgraph.types import Command
|
|||||||
from langgraph.typing import ContextT
|
from langgraph.typing import ContextT
|
||||||
|
|
||||||
from deerflow.agents.thread_state import ThreadState
|
from deerflow.agents.thread_state import ThreadState
|
||||||
from deerflow.sandbox.tools import get_thread_data, replace_virtual_path
|
|
||||||
|
|
||||||
|
|
||||||
@tool("view_image", parse_docstring=True)
|
@tool("view_image", parse_docstring=True)
|
||||||
@@ -32,6 +31,8 @@ def view_image_tool(
|
|||||||
Args:
|
Args:
|
||||||
image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp.
|
image_path: Absolute path to the image file. Common formats supported: jpg, jpeg, png, webp.
|
||||||
"""
|
"""
|
||||||
|
from deerflow.sandbox.tools import get_thread_data, replace_virtual_path
|
||||||
|
|
||||||
# Replace virtual path with actual path
|
# Replace virtual path with actual path
|
||||||
# /mnt/user-data/* paths are mapped to thread-specific directories
|
# /mnt/user-data/* paths are mapped to thread-specific directories
|
||||||
thread_data = get_thread_data(runtime)
|
thread_data = get_thread_data(runtime)
|
||||||
|
|||||||
@@ -19,11 +19,7 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||||
"pytest>=8.0.0",
|
|
||||||
"redis>=7.4.0",
|
|
||||||
"ruff>=0.14.11",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.workspace]
|
[tool.uv.workspace]
|
||||||
members = ["packages/harness"]
|
members = ["packages/harness"]
|
||||||
|
|||||||
@@ -1,268 +0,0 @@
|
|||||||
"""Actor framework benchmarks — throughput, latency, concurrency."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
|
|
||||||
from deerflow.actor import Actor, ActorSystem, Middleware
|
|
||||||
|
|
||||||
|
|
||||||
class NoopActor(Actor):
    """Actor that replies to every message with the message itself."""

    async def on_receive(self, message):
        # Echo: the reply is the incoming message, unchanged.
        reply = message
        return reply
|
|
||||||
|
|
||||||
|
|
||||||
class CounterActor(Actor):
    """Actor that keeps an integer counter.

    ``"inc"`` increments the counter; every message, including ``"get"``,
    is answered with the current count.
    """

    async def on_started(self):
        # Counter starts at zero when the actor starts.
        self.count = 0

    async def on_receive(self, message):
        # Only "inc" mutates state; all messages reply with the count.
        if message == "inc":
            self.count += 1
        return self.count
|
|
||||||
|
|
||||||
|
|
||||||
class ChainActor(Actor):
    """Forwards message to next actor in chain.

    When ``next_ref`` is set, the message is passed on with ``ask`` and that
    reply is returned; otherwise the message itself is the reply.
    """

    # Set on the instance after spawning to link it to the next actor.
    next_ref = None

    async def on_receive(self, message):
        downstream = self.next_ref
        if downstream is None:
            return message
        return await downstream.ask(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ComputeActor(Actor):
    """Simulates CPU work via thread pool.

    The message is an iteration count; the reply is the Fibonacci number
    computed with that many steps through ``self.context.run_in_executor``.
    """

    async def on_receive(self, message):
        def fib(steps):
            # Iterative Fibonacci: after `steps` updates, `prev` holds F(steps).
            prev, cur = 0, 1
            for _ in range(steps):
                prev, cur = cur, prev + cur
            return prev

        return await self.context.run_in_executor(fib, message)
|
|
||||||
|
|
||||||
|
|
||||||
class CountMiddleware(Middleware):
    """Middleware that counts the messages passing through it."""

    def __init__(self):
        # Number of on_receive invocations seen so far.
        self.count = 0

    async def on_receive(self, ctx, message, next_fn):
        # Count first, then hand the message to the rest of the pipeline.
        self.count = self.count + 1
        reply = await next_fn(ctx, message)
        return reply
|
|
||||||
|
|
||||||
|
|
||||||
def fmt(n):
    """Format a count compactly for benchmark output.

    Args:
        n: Non-negative number to format.

    Returns:
        ``"<x.y>M"`` for values of at least one million, ``"<x>K"`` for
        values of at least one thousand, otherwise ``str(n)``. Values that
        would round to ``"1000K"`` (e.g. 999_999) are reported as ``"1.0M"``.
    """
    if n >= 1_000_000:
        return f"{n / 1_000_000:.1f}M"
    if n >= 1_000:
        thousands = f"{n / 1_000:.0f}"
        # Rounding can push values just under a million to "1000"; promote
        # those to the next unit instead of printing "1000K".
        if thousands == "1000":
            return "1.0M"
        return f"{thousands}K"
    return str(n)
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_tell_throughput(n=100_000):
    """Measure tell (fire-and-forget) throughput.

    Sends ``n`` ``"inc"`` messages to one ``CounterActor``, then asks for the
    count and prints the overall rate. A warning is printed when the count
    differs from ``n``.
    """
    system = ActorSystem("bench")
    counter_ref = await system.spawn(CounterActor, "counter", mailbox_size=n + 10)

    started_at = time.perf_counter()
    for _ in range(n):
        await counter_ref.tell("inc")

    # Wait for all messages to be processed
    count = await counter_ref.ask("get", timeout=30.0)
    if count != n:
        print(f" warning: expected {n} processed, got {count}")
    elapsed = time.perf_counter() - started_at

    await system.shutdown()

    rate = n / elapsed
    print(f" tell throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(rate))}/s")
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_ask_throughput(n=50_000):
    """Measure ask (request-response) throughput.

    Awaits ``n`` sequential ``ask`` calls against one ``NoopActor`` and
    prints the resulting rate.
    """
    system = ActorSystem("bench")
    echo_ref = await system.spawn(NoopActor, "echo")

    started_at = time.perf_counter()
    remaining = n
    while remaining > 0:
        await echo_ref.ask("ping")
        remaining -= 1
    elapsed = time.perf_counter() - started_at

    await system.shutdown()

    rate = n / elapsed
    print(f" ask throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(rate))}/s")
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_ask_latency(n=10_000):
    """Measure ask round-trip latency percentiles.

    After 100 warmup calls, times ``n`` sequential ``ask`` calls and prints
    the p50, p99 and p99.9 latencies in microseconds.
    """
    system = ActorSystem("bench")
    echo_ref = await system.spawn(NoopActor, "echo")

    # Warmup
    for _ in range(100):
        await echo_ref.ask("warmup")

    samples = []
    for _ in range(n):
        began = time.perf_counter()
        await echo_ref.ask("ping")
        samples.append((time.perf_counter() - began) * 1_000_000)  # microseconds

    await system.shutdown()

    samples.sort()
    size = len(samples)
    # Percentiles are picked by index into the sorted samples.
    p50 = samples[size // 2]
    p99 = samples[int(size * 0.99)]
    p999 = samples[int(size * 0.999)]
    print(f" ask latency: p50={p50:.0f}µs p99={p99:.0f}µs p99.9={p999:.0f}µs")
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_concurrent_actors(num_actors=1000, msgs_per_actor=100):
    """Measure throughput with many concurrent actors.

    Spawns ``num_actors`` ``CounterActor`` instances, sends each of them
    ``msgs_per_actor`` ``"inc"`` messages concurrently, and prints the rate
    together with the number of messages not reflected in the counters.
    """
    system = ActorSystem("bench")
    refs = [
        await system.spawn(CounterActor, f"a{idx}", mailbox_size=msgs_per_actor + 10)
        for idx in range(num_actors)
    ]

    started_at = time.perf_counter()

    async def send_batch(ref, n):
        for sent in range(1, n + 1):
            await ref.tell("inc")
            # Yield control every 50 msgs so actor loops can drain
            if sent % 50 == 0:
                await asyncio.sleep(0)
        return await ref.ask("get", timeout=30.0)

    results = await asyncio.gather(*(send_batch(ref, msgs_per_actor) for ref in refs))
    elapsed = time.perf_counter() - started_at

    total = num_actors * msgs_per_actor
    rate = total / elapsed
    loss = total - sum(results)
    print(f" {num_actors} actors × {msgs_per_actor} msgs: {fmt(total)} in {elapsed:.2f}s = {fmt(int(rate))}/s (loss: {loss})")

    await system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_actor_chain(depth=100):
    """Measure ask latency through a chain of actors (hop overhead).

    Spawns ``depth`` ``ChainActor`` instances, links each to the next, sends
    one ``"ping"`` to the first and prints total and per-hop time.
    """
    system = ActorSystem("bench")
    refs = [await system.spawn(ChainActor, f"c{idx}") for idx in range(depth)]

    # Link chain: c0 → c1 → ... → c99
    for upstream, downstream in zip(refs, refs[1:]):
        upstream._cell.actor.next_ref = downstream

    started_at = time.perf_counter()
    result = await refs[0].ask("ping", timeout=30.0)
    elapsed = time.perf_counter() - started_at

    assert result == "ping"
    per_hop = elapsed / depth * 1_000_000  # µs
    print(f" chain {depth} hops: {elapsed*1000:.1f}ms total, {per_hop:.0f}µs/hop")

    await system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_middleware_overhead(n=50_000):
    """Measure overhead of middleware pipeline.

    Times ``n`` ``ask`` calls against a plain ``NoopActor`` and against one
    spawned with a single ``CountMiddleware``, then prints the relative
    difference as a percentage.
    """
    mw = CountMiddleware()

    system_plain = ActorSystem("plain")
    ref_plain = await system_plain.spawn(NoopActor, "echo")

    system_mw = ActorSystem("mw")
    ref_mw = await system_mw.spawn(NoopActor, "echo", middlewares=[mw])

    async def timed_asks(ref):
        # Time n sequential ask round-trips against `ref`.
        began = time.perf_counter()
        for _ in range(n):
            await ref.ask("p")
        return time.perf_counter() - began

    # Plain first, then with middleware.
    plain_elapsed = await timed_asks(ref_plain)
    mw_elapsed = await timed_asks(ref_mw)

    overhead = ((mw_elapsed - plain_elapsed) / plain_elapsed) * 100
    print(f" middleware overhead: {overhead:+.1f}% ({fmt(n)} ask calls, 1 middleware)")

    await system_plain.shutdown()
    await system_mw.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_executor_parallel(num_tasks=16):
    """Measure thread pool parallelism with CPU work.

    Spawns ``num_tasks`` ``ComputeActor`` instances on a system created with
    ``executor_workers=8``, asks each one for ``fib(10_000)`` concurrently
    and prints the elapsed time and task rate.
    """
    system = ActorSystem("bench", executor_workers=8)
    refs = []
    for idx in range(num_tasks):
        refs.append(await system.spawn(ComputeActor, f"cpu{idx}"))

    started_at = time.perf_counter()
    # Only the elapsed time is reported; the replies themselves are not used.
    await asyncio.gather(*(ref.ask(10_000, timeout=30.0) for ref in refs))
    elapsed = time.perf_counter() - started_at

    print(f" executor parallel: {num_tasks} fib(10K) in {elapsed*1000:.0f}ms ({num_tasks/elapsed:.0f} tasks/s)")

    await system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_spawn_teardown(n=5000):
    """Measure actor spawn + shutdown speed.

    Spawns ``n`` ``NoopActor`` instances, then shuts the system down, and
    prints how long each phase took.
    """
    system = ActorSystem("bench")

    spawn_began = time.perf_counter()
    # The refs are held in a list for the duration of the benchmark.
    refs = [await system.spawn(NoopActor, f"a{idx}") for idx in range(n)]
    spawn_elapsed = time.perf_counter() - spawn_began

    shutdown_began = time.perf_counter()
    await system.shutdown()
    shutdown_elapsed = time.perf_counter() - shutdown_began

    print(f" spawn {n}: {spawn_elapsed*1000:.0f}ms ({n/spawn_elapsed:.0f}/s)")
    print(f" shutdown {n}: {shutdown_elapsed*1000:.0f}ms")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
    """Run every actor-framework benchmark, grouped under section headings."""
    rule = "=" * 60

    # (heading, benchmarks) pairs, executed strictly in this order.
    sections = (
        ("[Throughput]", (bench_tell_throughput, bench_ask_throughput)),
        ("[Latency]", (bench_ask_latency, bench_actor_chain)),
        ("[Concurrency]", (bench_concurrent_actors, bench_executor_parallel)),
        ("[Overhead]", (bench_middleware_overhead,)),
        ("[Lifecycle]", (bench_spawn_teardown,)),
    )

    print(rule)
    print(" Actor Framework Benchmarks")
    print(rule)
    print()

    for heading, benchmarks in sections:
        print(heading)
        for benchmark in benchmarks:
            await benchmark()
        print()

    print(rule)
    print(" Done")
    print(rule)
|
|
||||||
|
|
||||||
|
|
||||||
# Script entry point: run the whole benchmark suite on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|
||||||
@@ -1,273 +0,0 @@
|
|||||||
"""RedisMailbox benchmark: throughput, latency, concurrency, backpressure."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
|
|
||||||
import redis.asyncio as redis
|
|
||||||
|
|
||||||
from deerflow.actor import Actor, ActorSystem
|
|
||||||
from deerflow.actor.mailbox_redis import RedisMailbox
|
|
||||||
|
|
||||||
|
|
||||||
class EchoActor(Actor):
    """Actor that replies with exactly the message it was given."""

    async def on_receive(self, message):
        """Return ``message`` unchanged."""
        return message
|
|
||||||
|
|
||||||
|
|
||||||
class CounterActor(Actor):
    """Actor that counts ``"inc"`` messages and reports the count on every reply."""

    async def on_started(self):
        """Start counting from zero."""
        self.count = 0

    async def on_receive(self, message):
        """Bump the counter for ``"inc"``; always reply with the current count."""
        # Every message ("inc", "get" or anything else) is answered with the
        # counter value; only "inc" changes it first.
        if message == "inc":
            self.count += 1
        return self.count
|
|
||||||
|
|
||||||
|
|
||||||
def fmt(n):
    """Format a count compactly for benchmark output.

    Values of one million and above are shown in millions with one decimal
    (``"2.5M"``), values of one thousand and above in whole thousands
    (``"20K"``), and anything smaller as-is.

    Args:
        n: The number to format.

    Returns:
        The abbreviated string.
    """
    if n >= 1_000_000:
        return f"{n / 1_000_000:.1f}M"
    if n >= 1_000:
        thousands = f"{n / 1_000:.0f}"
        # Values just under one million round up to "1000" here; promote them
        # to the next unit instead of printing the awkward "1000K".
        if thousands == "1000":
            return "1.0M"
        return f"{thousands}K"
    return str(n)
|
|
||||||
|
|
||||||
|
|
||||||
async def _redis_client():
    """Create a client for the Redis server at 127.0.0.1:6379 and ping it.

    The ping runs before the client is returned, so an unreachable server
    surfaces here rather than in the middle of a benchmark.
    """
    connection = redis.Redis(host="127.0.0.1", port=6379, decode_responses=False)
    await connection.ping()
    return connection
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_redis_ask_throughput(n=20_000):
    """Measure sequential ask round-trips per second through a RedisMailbox.

    Sends ``n`` asks one at a time to an echo actor whose mailbox is a Redis
    queue and prints the total time and message rate.
    """
    redis_conn = await _redis_client()

    # Start from an empty queue so leftovers from earlier runs are not counted.
    queue_key = "deerflow:bench:redis:ask"
    await redis_conn.delete(queue_key)

    redis_mailbox = RedisMailbox(redis_conn.connection_pool, queue_key, brpop_timeout=0.05)
    actor_system = ActorSystem("bench-redis")
    echo = await actor_system.spawn(EchoActor, "echo", mailbox=redis_mailbox)

    t0 = time.perf_counter()
    sent = 0
    while sent < n:
        await echo.ask("ping", timeout=5.0)
        sent += 1
    elapsed = time.perf_counter() - t0

    await actor_system.shutdown()

    per_second = n / elapsed
    print(f" redis ask throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(per_second))}/s")
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_redis_tell_throughput(n=50_000):
    """Measure fire-and-forget throughput through a RedisMailbox.

    Sends ``n`` ``"inc"`` tells to a counter actor, then asks for the final
    count. The timed region includes that closing ask. Prints the rate and
    how many increments were missing from the count.
    """
    redis_conn = await _redis_client()

    # Start from an empty queue so leftovers from earlier runs are not counted.
    queue_key = "deerflow:bench:redis:tell"
    await redis_conn.delete(queue_key)

    redis_mailbox = RedisMailbox(redis_conn.connection_pool, queue_key, brpop_timeout=0.05)
    actor_system = ActorSystem("bench-redis")
    counter = await actor_system.spawn(CounterActor, "counter", mailbox=redis_mailbox)

    t0 = time.perf_counter()
    remaining = n
    while remaining > 0:
        await counter.tell("inc")
        remaining -= 1
    observed = await counter.ask("get", timeout=30.0)
    elapsed = time.perf_counter() - t0

    await actor_system.shutdown()

    per_second = n / elapsed
    missing = n - observed
    print(f" redis tell throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(per_second))}/s (loss: {missing})")
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_redis_ask_latency(n=5_000):
    """Measure per-ask latency through a RedisMailbox and print p50/p99/p99.9.

    100 untimed warm-up asks are issued first; then ``n`` asks are timed
    individually and the percentiles are reported in microseconds.
    """
    redis_conn = await _redis_client()

    # Start from an empty queue so leftovers from earlier runs are not counted.
    queue_key = "deerflow:bench:redis:latency"
    await redis_conn.delete(queue_key)

    redis_mailbox = RedisMailbox(redis_conn.connection_pool, queue_key, brpop_timeout=0.05)
    actor_system = ActorSystem("bench-redis")
    echo = await actor_system.spawn(EchoActor, "echo", mailbox=redis_mailbox)

    # Warm-up round trips are excluded from the samples.
    for _ in range(100):
        await echo.ask("warmup", timeout=5.0)

    samples_us = []
    for _ in range(n):
        began = time.perf_counter()
        await echo.ask("ping", timeout=5.0)
        finished = time.perf_counter()
        samples_us.append((finished - began) * 1_000_000)

    await actor_system.shutdown()

    # Percentiles are picked by index from the sorted samples.
    ordered = sorted(samples_us)
    size = len(ordered)
    p50 = ordered[size // 2]
    p99 = ordered[int(size * 0.99)]
    p999 = ordered[int(size * 0.999)]
    print(f" redis ask latency: p50={p50:.0f}µs p99={p99:.0f}µs p99.9={p999:.0f}µs")
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_redis_concurrent_actors(num_actors=200, msgs_per_actor=100):
    """Measure aggregate tell throughput across many Redis-backed actors.

    Each of ``num_actors`` counter actors gets its own Redis queue. One
    sender coroutine per actor pushes ``msgs_per_actor`` ``"inc"`` tells and
    then asks for the count; all senders run concurrently. Prints the
    combined rate and the number of increments missing from the counts.
    """
    redis_conn = await _redis_client()
    actor_system = ActorSystem("bench-redis")

    counters = []
    for idx in range(num_actors):
        queue_key = f"deerflow:bench:redis:conc:{idx}"
        # Start each actor from an empty queue.
        await redis_conn.delete(queue_key)
        box = RedisMailbox(redis_conn.connection_pool, queue_key, brpop_timeout=0.05)
        counter = await actor_system.spawn(CounterActor, f"a{idx}", mailbox=box)
        counters.append(counter)

    t0 = time.perf_counter()

    async def send_batch(ref, n):
        """Send n increments to ref, yielding periodically, then fetch its count."""
        for sent in range(1, n + 1):
            await ref.tell("inc")
            # Yield to the event loop after every 50th message.
            if sent % 50 == 0:
                await asyncio.sleep(0)
        return await ref.ask("get", timeout=30.0)

    counts = await asyncio.gather(*[send_batch(counter, msgs_per_actor) for counter in counters])
    elapsed = time.perf_counter() - t0

    total = num_actors * msgs_per_actor
    missing = total - sum(counts)
    per_second = total / elapsed
    print(
        f" redis concurrency: {num_actors} actors × {msgs_per_actor} msgs = {fmt(total)} in {elapsed:.2f}s = {fmt(int(per_second))}/s (loss: {missing})"
    )

    await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_redis_maxlen_backpressure(total_messages=20_000, maxlen=100, ask_timeout=0.01, ask_concurrency=200):
    """Report how a bounded RedisMailbox behaves when it is flooded.

    Two independent phases are run, each with its own queue and system:

    1. ``total_messages`` tells are sent to a counter actor; the processed
       count and the number of dead letters are printed.
    2. ``total_messages`` asks are issued against an echo actor with at most
       ``ask_concurrency`` in flight; successes, timeouts and rejections are
       tallied and printed.

    Args:
        total_messages: Messages sent in each phase.
        maxlen: ``maxlen`` passed to both RedisMailbox instances.
        ask_timeout: Per-ask timeout in seconds for phase 2.
        ask_concurrency: Maximum simultaneous asks in phase 2.
    """
    redis_conn = await _redis_client()

    # ---- Phase 1: saturate with tell; dropped messages become dead letters ----
    tell_key = "deerflow:bench:redis:bp:tell"
    await redis_conn.delete(tell_key)
    tell_box = RedisMailbox(redis_conn.connection_pool, tell_key, maxlen=maxlen, brpop_timeout=0.05)

    tell_system = ActorSystem("bench-redis-bp-tell")
    counter = await tell_system.spawn(CounterActor, "counter", mailbox=tell_box)

    pushed = 0
    while pushed < total_messages:
        await counter.tell("inc")
        pushed += 1

    # Short pause before reading the count back.
    await asyncio.sleep(0.2)
    processed = await counter.ask("get", timeout=10.0)
    dropped = len(tell_system.dead_letters)
    drop_rate = dropped / total_messages if total_messages else 0.0

    print(
        f" redis maxlen tell: maxlen={maxlen}, sent={fmt(total_messages)}, processed={fmt(processed)}, dropped={fmt(dropped)} ({drop_rate:.1%})"
    )

    await tell_system.shutdown()

    # ---- Phase 2: ask timeout rate under pressure ----
    ask_key = "deerflow:bench:redis:bp:ask"
    await redis_conn.delete(ask_key)
    ask_box = RedisMailbox(redis_conn.connection_pool, ask_key, maxlen=maxlen, brpop_timeout=0.05)

    ask_system = ActorSystem("bench-redis-bp-ask")
    echo = await ask_system.spawn(EchoActor, "echo", mailbox=ask_box)

    gate = asyncio.Semaphore(ask_concurrency)

    async def attempt(i):
        """Issue one ask under the concurrency gate; return None on success or the failure reason."""
        async with gate:
            try:
                await echo.ask(i, timeout=ask_timeout)
            except asyncio.TimeoutError:
                return "timeout"
            except Exception:  # MailboxFullError or other rejection
                return "rejected"
            return None

    outcomes = await asyncio.gather(*[attempt(i) for i in range(total_messages)])
    ok = outcomes.count(None)
    timeout_count = outcomes.count("timeout")
    rejected_count = outcomes.count("rejected")
    fail_rate = (total_messages - ok) / total_messages if total_messages else 0.0

    print(
        f" redis maxlen ask: maxlen={maxlen}, total={fmt(total_messages)}, ok={fmt(ok)}, "
        f"timeout={fmt(timeout_count)}, rejected={fmt(rejected_count)} (fail: {fail_rate:.1%}), "
        f"ask_timeout={ask_timeout}s, concurrency={ask_concurrency}"
    )

    await ask_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def bench_redis_put_batch(n=50_000, batch_size=100):
    """put_batch: N messages in ceil(N/batch_size) round-trips instead of N.

    Pre-builds batches of ``"inc"`` envelopes, pushes each batch with one
    ``mailbox.put_batch`` call, then asks the counter actor for its count.
    Prints the enqueue-only rate and the end-to-end rate (which includes
    waiting for the final ``"get"`` reply).

    Args:
        n: Total number of messages to enqueue.
        batch_size: Maximum envelopes per ``put_batch`` call.
    """
    client = await _redis_client()

    queue = "deerflow:bench:redis:batch"
    # Start from an empty queue so leftovers from earlier runs are not counted.
    await client.delete(queue)

    mailbox = RedisMailbox(client.connection_pool, queue, brpop_timeout=0.05)
    system = ActorSystem("bench-redis-batch")
    ref = await system.spawn(CounterActor, "counter", mailbox=mailbox)

    from deerflow.actor.ref import _Envelope

    # Include a final short batch when n is not a multiple of batch_size so
    # exactly n messages are enqueued. Otherwise the remainder would never be
    # sent and would be misreported as loss below.
    batches = [
        [_Envelope(payload="inc") for _ in range(min(batch_size, n - offset))]
        for offset in range(0, n, batch_size)
    ]
    round_trips = len(batches)

    t0 = time.perf_counter()
    for batch in batches:
        await mailbox.put_batch(batch)
    enqueue_elapsed = time.perf_counter() - t0

    count = await ref.ask("get", timeout=60.0)
    total_elapsed = time.perf_counter() - t0

    loss = n - count
    enqueue_rate = n / enqueue_elapsed
    print(
        f" redis put_batch push: {fmt(n)} msgs in {enqueue_elapsed:.3f}s = {fmt(int(enqueue_rate))}/s "
        f"(batch={batch_size}, round-trips={round_trips})"
    )
    print(
        f" redis put_batch total: end-to-end {total_elapsed:.2f}s = {fmt(int(n / total_elapsed))}/s "
        f"(consume bottleneck, loss={loss})"
    )

    await system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
    """Run every RedisMailbox benchmark in sequence between banner lines."""
    rule = "=" * 72

    # Benchmarks run strictly in this order.
    benchmarks = (
        bench_redis_tell_throughput,
        bench_redis_ask_throughput,
        bench_redis_ask_latency,
        bench_redis_concurrent_actors,
        bench_redis_put_batch,
        bench_redis_maxlen_backpressure,
    )

    print(rule)
    print(" RedisMailbox Benchmarks")
    print(rule)
    print()

    for benchmark in benchmarks:
        await benchmark()

    print()
    print(rule)
    print(" Done")
    print(rule)
|
|
||||||
|
|
||||||
|
|
||||||
# Script entry point: run the Redis benchmark suite on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
|
||||||
@@ -1,534 +0,0 @@
|
|||||||
"""Tests for the async Actor framework."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from deerflow.actor import (
|
|
||||||
Actor,
|
|
||||||
ActorRef,
|
|
||||||
ActorSystem,
|
|
||||||
AllForOneStrategy,
|
|
||||||
Directive,
|
|
||||||
Middleware,
|
|
||||||
OneForOneStrategy,
|
|
||||||
)
|
|
||||||
from deerflow.actor.ref import ActorStoppedError
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Basic actors for testing
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class EchoActor(Actor):
    """Test actor that replies with the message it received."""

    async def on_receive(self, message):
        """Return ``message`` unchanged."""
        return message
|
|
||||||
|
|
||||||
|
|
||||||
class CounterActor(Actor):
    """Test actor holding a counter: ``"inc"`` bumps it, ``"get"`` reads it."""

    async def on_started(self):
        """Reset the counter to zero."""
        self.count = 0

    async def on_receive(self, message):
        """Handle ``"inc"`` / ``"get"``; any other message is ignored."""
        if message == "get":
            return self.count
        if message == "inc":
            self.count += 1
        # "inc" and unrecognised messages produce no reply value.
        return None
|
|
||||||
|
|
||||||
|
|
||||||
class CrashActor(Actor):
    """Test actor that raises on ``"crash"`` and answers ``"ok"`` otherwise."""

    async def on_receive(self, message):
        """Raise ``ValueError("boom")`` for ``"crash"``; reply ``"ok"`` to anything else."""
        if message != "crash":
            return "ok"
        raise ValueError("boom")
|
|
||||||
|
|
||||||
|
|
||||||
class ParentActor(Actor):
    """Parent that spawns one ``CrashActor`` child under a one-for-one strategy."""

    def __init__(self):
        # child_ref is filled in by on_started.
        self.child_ref: ActorRef | None = None
        self.restarts = 0

    def supervisor_strategy(self):
        """Use a one-for-one strategy with max_restarts=3, within_seconds=60."""
        return OneForOneStrategy(max_restarts=3, within_seconds=60)

    async def on_started(self):
        """Spawn the ``"child"`` actor and remember its ref."""
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        """Reply to ``"get_child"`` with the child's ref; ignore everything else."""
        if message != "get_child":
            return None
        return self.child_ref
|
|
||||||
|
|
||||||
|
|
||||||
class StopOnCrashParent(Actor):
    """Parent whose supervisor decider always answers ``Directive.stop``."""

    def supervisor_strategy(self):
        """Use a one-for-one strategy whose decider returns ``Directive.stop`` for any error."""
        return OneForOneStrategy(decider=lambda _: Directive.stop)

    async def on_started(self):
        """Spawn the ``"child"`` actor and remember its ref."""
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        """Reply to ``"get_child"`` with the child's ref; ignore everything else."""
        if message != "get_child":
            return None
        return self.child_ref
|
|
||||||
|
|
||||||
|
|
||||||
class AllForOneParent(Actor):
    """Parent with two children (a counter and a crasher) under an all-for-one strategy."""

    def supervisor_strategy(self):
        """Use an all-for-one strategy with max_restarts=2, within_seconds=60."""
        return AllForOneStrategy(max_restarts=2, within_seconds=60)

    async def on_started(self):
        """Spawn ``"c1"`` (CounterActor) and then ``"c2"`` (CrashActor)."""
        self.c1 = await self.context.spawn(CounterActor, "c1")
        self.c2 = await self.context.spawn(CrashActor, "c2")

    async def on_receive(self, message):
        """Reply to ``"get_children"`` with ``(c1, c2)``; ignore everything else."""
        if message != "get_children":
            return None
        return (self.c1, self.c2)
|
|
||||||
|
|
||||||
|
|
||||||
class LifecycleActor(Actor):
    """Actor that records its lifecycle callbacks in class-level flags.

    The flags live on the class, so tests reset them before spawning and
    read them back afterwards.
    """

    # Set by on_started.
    started = False
    # Set by on_stopped.
    stopped = False
    # Holds the error passed to on_restart, if that hook has run.
    restarted_with: Exception | None = None

    async def on_started(self):
        """Flag that the start hook ran."""
        LifecycleActor.started = True

    async def on_stopped(self):
        """Flag that the stop hook ran."""
        LifecycleActor.stopped = True

    async def on_restart(self, error):
        """Remember the error that was handed to the restart hook."""
        LifecycleActor.restarted_with = error

    async def on_receive(self, message):
        """Raise ``RuntimeError`` for ``"crash"``; reply ``"alive"`` otherwise."""
        if message != "crash":
            return "alive"
        raise RuntimeError("lifecycle crash")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Tests
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestBasicMessaging:
    """tell/ask delivery, ask timeouts, and messaging actors that were stopped."""

    @pytest.mark.anyio
    async def test_tell_and_ask(self):
        """ask() on an echo actor returns the message that was sent."""
        actor_system = ActorSystem("test")
        echo = await actor_system.spawn(EchoActor, "echo")
        reply = await echo.ask("hello")
        assert reply == "hello"
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_ask_timeout(self):
        """ask() raises asyncio.TimeoutError when the handler outlasts the timeout."""

        class SlowActor(Actor):
            async def on_receive(self, message):
                # Far longer than the 0.1s timeout used below.
                await asyncio.sleep(10)

        actor_system = ActorSystem("test")
        slow = await actor_system.spawn(SlowActor, "slow")
        with pytest.raises(asyncio.TimeoutError):
            await slow.ask("hi", timeout=0.1)
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_tell_fire_and_forget(self):
        """Three tells are all processed by the time the count is read."""
        actor_system = ActorSystem("test")
        counter = await actor_system.spawn(CounterActor, "counter")
        for _ in range(3):
            await counter.tell("inc")
        # Give the actor time to process
        await asyncio.sleep(0.05)
        total = await counter.ask("get")
        assert total == 3
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_ask_stopped_actor(self):
        """ask() on a stopped actor raises ActorStoppedError."""
        actor_system = ActorSystem("test")
        echo = await actor_system.spawn(EchoActor, "echo")
        echo.stop()
        # Let the stop take effect before asking.
        await asyncio.sleep(0.05)
        with pytest.raises(ActorStoppedError):
            await echo.ask("hello")
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_tell_stopped_actor_goes_to_dead_letters(self):
        """tell() to a stopped actor leaves at least one dead letter on the system."""
        actor_system = ActorSystem("test")
        echo = await actor_system.spawn(EchoActor, "echo")
        echo.stop()
        # Let the stop take effect before telling.
        await asyncio.sleep(0.05)
        await echo.tell("orphan")
        assert len(actor_system.dead_letters) >= 1
        await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestActorPath:
    """Paths are built from the system name, the parent name and the actor name."""

    @pytest.mark.anyio
    async def test_root_actor_path(self):
        """A root actor's path is ``/<system>/<name>``."""
        actor_system = ActorSystem("app")
        echo = await actor_system.spawn(EchoActor, "echo")
        assert echo.path == "/app/echo"
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_child_actor_path(self):
        """A child's path is nested under its parent's path."""
        actor_system = ActorSystem("app")
        parent_ref = await actor_system.spawn(ParentActor, "parent")
        child_ref: ActorRef = await parent_ref.ask("get_child")
        assert child_ref.path == "/app/parent/child"
        await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestLifecycle:
    """Start/stop hooks fire and shutdown stops every spawned actor."""

    @pytest.mark.anyio
    async def test_on_started_called(self):
        """Spawning an actor runs its on_started hook."""
        # Reset the class-level flag so earlier tests cannot satisfy the assert.
        LifecycleActor.started = False
        actor_system = ActorSystem("test")
        await actor_system.spawn(LifecycleActor, "lc")
        assert LifecycleActor.started is True
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_on_stopped_called(self):
        """Stopping an actor runs its on_stopped hook."""
        # Reset the class-level flag so earlier tests cannot satisfy the assert.
        LifecycleActor.stopped = False
        actor_system = ActorSystem("test")
        actor = await actor_system.spawn(LifecycleActor, "lc")
        actor.stop()
        # Wait briefly before checking the flag.
        await asyncio.sleep(0.1)
        assert LifecycleActor.stopped is True
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_shutdown_stops_all(self):
        """After system.shutdown() no spawned actor reports is_alive."""
        actor_system = ActorSystem("test")
        first = await actor_system.spawn(EchoActor, "a")
        second = await actor_system.spawn(EchoActor, "b")
        await actor_system.shutdown()
        assert not first.is_alive
        assert not second.is_alive
|
|
||||||
|
|
||||||
|
|
||||||
class TestSupervision:
    """Supervisor strategies: restart, stop, restart limits and all-for-one."""

    @pytest.mark.anyio
    async def test_restart_on_crash(self):
        """A crashed child under ParentActor stays alive and keeps answering."""
        actor_system = ActorSystem("test")
        parent_ref = await actor_system.spawn(ParentActor, "parent")
        child_ref: ActorRef = await parent_ref.ask("get_child")

        # Crash the child
        with pytest.raises(ValueError, match="boom"):
            await child_ref.ask("crash")
        await asyncio.sleep(0.1)

        # Child should still be alive (restarted)
        assert child_ref.is_alive
        reply = await child_ref.ask("safe")
        assert reply == "ok"
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_stop_directive(self):
        """A decider returning Directive.stop leaves the crashed child stopped."""
        actor_system = ActorSystem("test")
        parent_ref = await actor_system.spawn(StopOnCrashParent, "parent")
        child_ref: ActorRef = await parent_ref.ask("get_child")

        with pytest.raises(ValueError, match="boom"):
            await child_ref.ask("crash")
        await asyncio.sleep(0.1)

        assert not child_ref.is_alive
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_restart_limit_exceeded(self):
        """Crashing more often than max_restarts allows ends with the child stopped."""
        actor_system = ActorSystem("test")

        class StrictParent(Actor):
            def supervisor_strategy(self):
                return OneForOneStrategy(max_restarts=2, within_seconds=60)

            async def on_started(self):
                self.child_ref = await self.context.spawn(CrashActor, "child")

            async def on_receive(self, message):
                # Any message is answered with the child's ref.
                return self.child_ref

        parent_ref = await actor_system.spawn(StrictParent, "parent")
        child_ref: ActorRef = await parent_ref.ask("any")

        # Exhaust restart limit
        attempts = 0
        while attempts < 3:
            attempts += 1
            try:
                await child_ref.ask("crash")
            except (ValueError, ActorStoppedError):
                # Either the crash itself or an already-stopped child is acceptable here.
                pass
            await asyncio.sleep(0.05)

        # After exceeding limit, child should be stopped
        assert not child_ref.is_alive
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_all_for_one_restarts_siblings(self):
        """Crashing one child under all-for-one resets its sibling's state too."""
        actor_system = ActorSystem("test")
        parent_ref = await actor_system.spawn(AllForOneParent, "parent")
        counter_ref, crasher_ref = await parent_ref.ask("get_children")

        # Increment counter on c1
        await counter_ref.tell("inc")
        await asyncio.sleep(0.05)
        before = await counter_ref.ask("get")
        assert before == 1

        # Crash c2 → AllForOne should restart both
        try:
            await crasher_ref.ask("crash")
        except ValueError:
            pass
        await asyncio.sleep(0.1)

        # c1 was restarted, counter should be 0
        after = await counter_ref.ask("get")
        assert after == 0
        await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeadLetters:
    """Dead-letter callbacks registered on the system are invoked."""

    @pytest.mark.anyio
    async def test_dead_letter_callback(self):
        """A tell to a stopped actor reaches the on_dead_letter callback."""
        captured = []
        actor_system = ActorSystem("test")
        actor_system.on_dead_letter(captured.append)

        echo = await actor_system.spawn(EchoActor, "echo")
        echo.stop()
        # Let the stop take effect before telling.
        await asyncio.sleep(0.05)
        await echo.tell("orphan")

        assert len(captured) >= 1
        # The most recent dead letter carries the message just sent.
        assert captured[-1].message == "orphan"
        await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestDuplicateNames:
    """Spawning two root actors with the same name is rejected."""

    @pytest.mark.anyio
    async def test_duplicate_root_name_raises(self):
        """The second spawn of ``"echo"`` raises ValueError mentioning 'already exists'."""
        actor_system = ActorSystem("test")
        await actor_system.spawn(EchoActor, "echo")
        with pytest.raises(ValueError, match="already exists"):
            await actor_system.spawn(EchoActor, "echo")
        await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Middleware tests
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class LogMiddleware(Middleware):
    """Middleware that appends a text entry to ``self.log`` for each hook it sees."""

    def __init__(self):
        # Entries: "before:<message>", "after:<result>", "started", "stopped".
        self.log: list[str] = []

    async def on_receive(self, ctx, message, next_fn):
        """Log the incoming message, delegate to ``next_fn``, then log its result."""
        self.log.append(f"before:{message}")
        outcome = await next_fn(ctx, message)
        self.log.append(f"after:{outcome}")
        return outcome

    async def on_started(self, actor_ref):
        """Record that the start hook ran."""
        self.log.append("started")

    async def on_stopped(self, actor_ref):
        """Record that the stop hook ran."""
        self.log.append("stopped")
|
|
||||||
|
|
||||||
|
|
||||||
class TransformMiddleware(Middleware):
    """Uppercases string messages before passing to actor."""

    async def on_receive(self, ctx, message, next_fn):
        """Forward ``message`` to ``next_fn``, upper-cased when it is a ``str``."""
        # Non-string messages are forwarded untouched.
        forwarded = message.upper() if isinstance(message, str) else message
        return await next_fn(ctx, forwarded)
|
|
||||||
|
|
||||||
|
|
||||||
class TestExecutor:
    """Blocking work dispatched through ``context.run_in_executor``."""

    @pytest.mark.anyio
    async def test_run_in_executor(self):
        """Blocking function runs in thread pool without blocking event loop."""
        import time

        class BlockingActor(Actor):
            async def on_receive(self, message):
                # Simulate blocking I/O via thread pool; the executor's return
                # value is irrelevant, only completion matters.
                await self.context.run_in_executor(time.sleep, 0.01)
                return "done"

        system = ActorSystem("test", executor_workers=2)
        ref = await system.spawn(BlockingActor, "blocker")
        result = await ref.ask("go", timeout=5.0)
        assert result == "done"
        await system.shutdown()

    @pytest.mark.anyio
    async def test_concurrent_blocking_calls(self):
        """Multiple actors can run blocking I/O concurrently via shared pool."""
        import time

        class SlowActor(Actor):
            async def on_receive(self, message):
                await self.context.run_in_executor(time.sleep, 0.1)
                return "ok"

        system = ActorSystem("test", executor_workers=4)
        refs = [await system.spawn(SlowActor, f"s{i}") for i in range(4)]

        start = time.monotonic()
        results = await asyncio.gather(*[r.ask("go", timeout=5.0) for r in refs])
        elapsed = time.monotonic() - start

        assert all(r == "ok" for r in results)
        # 4 parallel × 0.1s should finish in ~0.1-0.2s, not 0.4s
        assert elapsed < 0.3
        await system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestMiddleware:
    """Middleware interception, lifecycle hooks, ordering and restart hooks."""

    @pytest.mark.anyio
    async def test_middleware_intercepts_messages(self):
        """A middleware sees both the incoming message and the actor's reply."""
        logger_mw = LogMiddleware()
        actor_system = ActorSystem("test")
        echo = await actor_system.spawn(EchoActor, "echo", middlewares=[logger_mw])
        reply = await echo.ask("hello")
        assert reply == "hello"
        assert "before:hello" in logger_mw.log
        assert "after:hello" in logger_mw.log
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_lifecycle_hooks(self):
        """on_started fires at spawn time and on_stopped after the actor is stopped."""
        logger_mw = LogMiddleware()
        actor_system = ActorSystem("test")
        echo = await actor_system.spawn(EchoActor, "echo", middlewares=[logger_mw])
        assert "started" in logger_mw.log
        echo.stop()
        # Wait briefly before checking for the stop entry.
        await asyncio.sleep(0.1)
        assert "stopped" in logger_mw.log
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_chain_order(self):
        """First middleware wraps outermost — sees original message."""
        outer = LogMiddleware()
        inner = TransformMiddleware()
        actor_system = ActorSystem("test")
        # Chain: outer(inner(actor)). outer logs original, inner uppercases, actor echoes
        echo = await actor_system.spawn(EchoActor, "echo", middlewares=[outer, inner])
        reply = await echo.ask("hello")
        assert reply == "HELLO"  # TransformMiddleware uppercased
        assert "before:hello" in outer.log  # LogMiddleware saw original
        assert "after:HELLO" in outer.log  # LogMiddleware saw transformed result
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_with_tell(self):
        """Spawning alone produces no ``before:`` log entries."""
        logger_mw = LogMiddleware()
        actor_system = ActorSystem("test")
        await actor_system.spawn(CounterActor, "counter", middlewares=[logger_mw])
        # tell goes through middleware too
        # NOTE(review): no message is actually sent in this test, so it only
        # checks that nothing was intercepted yet — confirm whether a tell()
        # followed by a positive assertion was intended.
        saw_receive = any("before:" in entry for entry in logger_mw.log)
        assert saw_receive is False
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_on_restart_hook(self):
        """on_restart is called on the middleware when a child actor is restarted."""

        class RestartTrackingMiddleware(Middleware):
            def __init__(self):
                # Errors passed to on_restart, in call order.
                self.restart_errors: list[Exception] = []

            async def on_restart(self, actor_ref, error):
                self.restart_errors.append(error)

        tracker = RestartTrackingMiddleware()

        class ChildSpawningParent(Actor):
            async def on_receive(self, message):
                if message != "spawn":
                    return None
                # The child is spawned on demand with the tracking middleware attached.
                return await self.context.spawn(CrashActor, "child", middlewares=[tracker])

        actor_system = ActorSystem("test")
        parent_ref = await actor_system.spawn(ChildSpawningParent, "parent")
        child_ref = await parent_ref.ask("spawn")

        # Crash the child — parent supervisor will restart it
        try:
            await child_ref.ask("crash")
        except ValueError:
            pass
        await asyncio.sleep(0.1)

        assert len(tracker.restart_errors) == 1
        assert isinstance(tracker.restart_errors[0], ValueError)
        await actor_system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
class TestAskErrorPropagation:
    """How ask() surfaces handler exceptions and late replies."""

    @pytest.mark.anyio
    async def test_ask_propagates_actor_exception(self):
        """ask() re-raises the original exception type when on_receive crashes."""

        class BoomActor(Actor):
            async def on_receive(self, message):
                raise ValueError("intentional crash")

        actor_system = ActorSystem("test")
        boom = await actor_system.spawn(BoomActor, "boom")
        with pytest.raises(ValueError, match="intentional crash"):
            await boom.ask("trigger")
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_ask_propagates_exception_while_supervised(self):
        """ask() gets the exception even when the actor is supervised (not stopped)."""

        class SometimesCrashActor(Actor):
            async def on_receive(self, message):
                if message != "crash":
                    return "ok"
                raise RuntimeError("supervised crash")

        actor_system = ActorSystem("test")
        actor = await actor_system.spawn(SometimesCrashActor, "sca")
        with pytest.raises(RuntimeError, match="supervised crash"):
            await actor.ask("crash")
        # Root actor keeps running after a crash (consecutive_failures, not restart)
        reply = await actor.ask("hello", timeout=2.0)
        assert reply == "ok"
        await actor_system.shutdown()

    @pytest.mark.anyio
    async def test_ask_timeout_late_reply_no_exception(self):
        """Late reply arriving after ask() timeout is silently dropped — no exception, no orphaned future."""

        class SlowActor(Actor):
            async def on_receive(self, message):
                # Longer than the 0.05s timeout of the first ask below.
                await asyncio.sleep(0.3)
                return "late"

        actor_system = ActorSystem("test")
        slow = await actor_system.spawn(SlowActor, "slow")

        with pytest.raises(asyncio.TimeoutError):
            await slow.ask("go", timeout=0.05)

        # Wait for actor to finish processing — late reply arrives, should be a no-op
        await asyncio.sleep(0.4)
        # System still functional: no orphaned futures, no leaked state
        assert slow.is_alive
        reply = await slow.ask("go", timeout=2.0)
        assert reply == "late"
        await actor_system.shutdown()
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from deerflow.actor import Actor, ActorSystem, MailboxFullError
|
|
||||||
from deerflow.actor.mailbox import BACKPRESSURE_BLOCK, BACKPRESSURE_DROP_NEW, BACKPRESSURE_FAIL, MemoryMailbox
|
|
||||||
|
|
||||||
|
|
||||||
class SlowActor(Actor):
    """Counter actor whose ``'inc'`` handler sleeps 10 ms before incrementing."""

    async def on_started(self):
        """Start counting from zero."""
        self.count = 0

    async def on_receive(self, message):
        """Handle ``'inc'`` (slow increment, no reply) and ``'get'`` (current count)."""
        if message == "get":
            return self.count
        if message == "inc":
            # The sleep keeps the actor busy so the mailbox can fill up.
            await asyncio.sleep(0.01)
            self.count += 1
        return None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
async def test_memory_mailbox_drop_new_policy_drops_tell_to_dead_letters():
    """Under ``BACKPRESSURE_DROP_NEW`` overflowing ``tell`` calls are dropped.

    Twenty ``'inc'`` messages are sent to a slow actor whose mailbox is
    built as ``MemoryMailbox(1, ...)`` (presumably a capacity of one —
    confirm against the mailbox implementation).  Fewer than twenty
    increments must be observed and the system's dead-letter list must
    be non-empty.
    """
    system = ActorSystem('bp')
    ref = await system.spawn(
        SlowActor,
        'slow',
        mailbox=MemoryMailbox(1, backpressure_policy=BACKPRESSURE_DROP_NEW),
    )

    try:
        # Overfill quickly
        for _ in range(20):
            await ref.tell('inc')

        await asyncio.sleep(0.4)
        count = await ref.ask('get', timeout=2.0)
    finally:
        # Always stop the system, even if the ask above times out.
        await system.shutdown()

    # Some messages should be dropped under drop_new
    assert count < 20
    assert len(system.dead_letters) > 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
async def test_memory_mailbox_fail_policy_rejects_ask_when_full():
    """Under ``BACKPRESSURE_FAIL`` an ``ask`` is eventually rejected.

    One ``tell`` occupies the mailbox, then up to thirty short ``ask``
    calls are issued; at least one of them must raise
    ``MailboxFullError``.  Timeouts of individual asks are tolerated.
    """
    fail_mailbox = MemoryMailbox(1, backpressure_policy=BACKPRESSURE_FAIL)
    system = ActorSystem('bp')
    ref = await system.spawn(SlowActor, 'slow', mailbox=fail_mailbox)

    # Fill queue with tell first
    await ref.tell('inc')

    # Then ask may be rejected when queue still full
    rejected = False
    remaining = 30
    while remaining > 0 and not rejected:
        remaining -= 1
        try:
            await ref.ask('inc', timeout=0.02)
        except MailboxFullError:
            rejected = True
        except asyncio.TimeoutError:
            # A slow reply is not the signal under test; try again.
            continue

    await system.shutdown()
    assert rejected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
async def test_memory_mailbox_block_policy_eventually_accepts():
    """Under ``BACKPRESSURE_BLOCK`` every ``tell`` is eventually processed.

    Ten ``'inc'`` messages are sent to a slow actor with a
    ``MemoryMailbox(1, ...)``; after a short wait the counter must be
    exactly ten, i.e. nothing was dropped on the ``tell`` path.
    """
    system = ActorSystem('bp')
    ref = await system.spawn(
        SlowActor,
        'slow',
        mailbox=MemoryMailbox(1, backpressure_policy=BACKPRESSURE_BLOCK),
    )

    try:
        for _ in range(10):
            await ref.tell('inc')

        await asyncio.sleep(0.25)
        count = await ref.ask('get', timeout=2.0)
    finally:
        # Always stop the system, even if the ask above times out.
        await system.shutdown()

    # Block policy should avoid dropping on tell path
    assert count == 10
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from deerflow.actor import Actor, ActorSystem, IdempotentActorMixin, RetryEnvelope, ask_with_retry
|
|
||||||
|
|
||||||
|
|
||||||
class FlakyIdempotentActor(IdempotentActorMixin, Actor):
    """Actor that is slow on its first ``'flaky'`` call and counts handler runs.

    Every message is routed through ``handle_idempotent`` with
    ``_handle`` as the handler; ``calls`` records how often ``_handle``
    actually executed.
    """

    async def on_started(self):
        self.calls = 0

    async def on_receive(self, message):
        reply = await self.handle_idempotent(message, self._handle)
        return reply

    async def _handle(self, payload):
        self.calls += 1
        if not (payload == 'flaky' and self.calls == 1):
            return f"ok:{payload}"
        # Very first handler run for 'flaky': answer late.
        await asyncio.sleep(0.02)
        return 'late'
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
async def test_ask_with_retry_timeout_raises():
    """``ask_with_retry`` raises ``asyncio.TimeoutError`` here and the actor survives.

    The call uses a 5 ms per-attempt timeout, three attempts, no jitter
    and a fixed idempotency key.  Afterwards the actor ref must still be
    alive.
    """
    system = ActorSystem('retry')
    ref = await system.spawn(FlakyIdempotentActor, 'a')

    try:
        with pytest.raises(asyncio.TimeoutError):
            await ask_with_retry(
                ref,
                'flaky',
                timeout=0.005,
                max_attempts=3,
                base_backoff_s=0.001,
                max_backoff_s=0.005,
                jitter_ratio=0.0,
                idempotency_key='k1',
            )

        # This helper retries timeout, but if each attempt times out it should raise.
        assert ref.is_alive
    finally:
        # Stop the system even when an expectation above fails.
        await system.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
async def test_idempotent_envelope_returns_cached_result():
    """Two envelopes sharing an idempotency key run the handler only once.

    Both asks must return ``'ok:x'`` and the actor's ``calls`` counter
    must stay at one.
    """
    system = ActorSystem('retry')
    ref = await system.spawn(FlakyIdempotentActor, 'a')

    try:
        m1 = RetryEnvelope.wrap('x', idempotency_key='same-key')
        m2 = RetryEnvelope.wrap('x', idempotency_key='same-key', attempt=2, max_attempts=3)

        r1 = await ref.ask(m1, timeout=1.0)
        r2 = await ref.ask(m2, timeout=1.0)

        assert r1 == 'ok:x'
        assert r2 == 'ok:x'
        # handler should run once due to idempotency cache
        actor = ref._cell.actor  # private access: the test inspects the live actor instance
        assert actor.calls == 1
    finally:
        # Stop the system even when an assertion above fails.
        await system.shutdown()
|
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
"""Tests for AioSandbox concurrent command serialization (#1433)."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def sandbox():
    """Return an ``AioSandbox`` built while ``AioSandboxClient`` is patched out."""
    # NOTE(review): indentation was lost in this view; the construction is
    # assumed to happen inside the ``patch`` context — confirm against the repo.
    client_patch = patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient")
    with client_patch:
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecuteCommandSerialization:
    """Check that overlapping ``execute_command`` calls run one at a time."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Three threads released together must never interleave in ``exec_command``."""
        import time

        events = []
        start_gate = threading.Barrier(3)

        def slow_exec(command, **kwargs):
            # Record entry and exit around a short sleep so overlap is visible.
            events.append(("enter", command))
            time.sleep(0.05)
            events.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def worker(cmd):
            start_gate.wait()  # ensure all threads contend for the lock simultaneously
            sandbox.execute_command(cmd)

        workers = [threading.Thread(target=worker, args=(f"cmd-{n}",)) for n in range(3)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        # Verify serialization: each "enter" should be followed by its own
        # "exit" before the next "enter" (no interleaving).
        enter_positions = [pos for pos, (kind, _) in enumerate(events) if kind == "enter"]
        exit_positions = [pos for pos, (kind, _) in enumerate(events) if kind == "exit"]
        assert len(enter_positions) == 3
        assert len(exit_positions) == 3
        for entered_at, exited_at in zip(enter_positions, exit_positions):
            assert exited_at == entered_at + 1, f"Interleaved execution detected: {events}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorObservationRetry:
    """Check ``ErrorObservation`` detection and the retry that follows it."""

    def test_retry_on_error_observation(self, sandbox):
        """An ``ErrorObservation`` output causes one retry whose output is returned."""
        seen_commands = []

        def mock_exec(command, **kwargs):
            seen_commands.append(command)
            if len(seen_commands) == 1:
                output = "'ErrorObservation' object has no attribute 'exit_code'"
            else:
                output = "success"
            return SimpleNamespace(data=SimpleNamespace(output=output))

        sandbox._client.shell.exec_command = mock_exec

        result = sandbox.execute_command("echo hello")
        assert result == "success"
        assert len(seen_commands) == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """Only the retry call carries an ``id`` keyword argument."""
        recorded_kwargs = []

        def mock_exec(command, **kwargs):
            recorded_kwargs.append(kwargs)
            first_call = len(recorded_kwargs) == 1
            output = "'ErrorObservation' object has no attribute 'exit_code'" if first_call else "ok"
            return SimpleNamespace(data=SimpleNamespace(output=output))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("test")
        assert len(recorded_kwargs) == 2
        first_kwargs, retry_kwargs = recorded_kwargs
        assert "id" not in first_kwargs
        assert "id" in retry_kwargs
        assert len(retry_kwargs["id"]) == 36  # UUID format

    def test_no_retry_on_clean_output(self, sandbox):
        """Ordinary output is returned after a single ``exec_command`` call."""
        seen_commands = []

        def mock_exec(command, **kwargs):
            seen_commands.append(command)
            return SimpleNamespace(data=SimpleNamespace(output="all good"))

        sandbox._client.shell.exec_command = mock_exec

        result = sandbox.execute_command("echo hello")
        assert result == "all good"
        assert len(seen_commands) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestListDirSerialization:
    """Check that ``list_dir`` runs its command while holding the sandbox lock."""

    def test_list_dir_uses_lock(self, sandbox):
        """``sandbox._lock`` must report locked at the moment ``exec_command`` runs."""
        canned_response = SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))
        original_exec = MagicMock(return_value=canned_response)
        lock_states = []

        def tracking_exec(command, **kwargs):
            # Sample the lock state from inside the call made by list_dir.
            lock_states.append(sandbox._lock.locked())
            return original_exec(command, **kwargs)

        sandbox._client.shell.exec_command = tracking_exec

        listing = sandbox.list_dir("/test")
        assert listing == ["/a", "/b"]
        assert lock_states == [True], "list_dir must hold the lock during exec_command"
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
from deerflow.community.aio_sandbox.local_backend import _format_container_mount
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_container_mount_uses_mount_syntax_for_docker_windows_paths():
    """For ``docker`` a drive-letter host path is emitted as a ``--mount`` bind spec."""
    expected = [
        "--mount",
        "type=bind,src=D:/deer-flow/backend/.deer-flow/threads,dst=/mnt/threads",
    ]

    actual = _format_container_mount("docker", "D:/deer-flow/backend/.deer-flow/threads", "/mnt/threads", False)

    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_container_mount_marks_docker_readonly_mounts():
    """For ``docker`` a read-only mount gets a ``readonly`` suffix in the bind spec."""
    expected = [
        "--mount",
        "type=bind,src=/host/path,dst=/mnt/path,readonly",
    ]

    actual = _format_container_mount("docker", "/host/path", "/mnt/path", True)

    assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_container_mount_keeps_volume_syntax_for_apple_container():
    """For the ``container`` runtime the classic ``-v host:dst:ro`` form is kept."""
    expected = [
        "-v",
        "/host/path:/mnt/path:ro",
    ]

    actual = _format_container_mount("container", "/host/path", "/mnt/path", True)

    assert actual == expected
|
||||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from deerflow.config.paths import Paths
|
from deerflow.config.paths import Paths, join_host_path
|
||||||
|
|
||||||
# ── ensure_thread_dirs ───────────────────────────────────────────────────────
|
# ── ensure_thread_dirs ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -31,6 +31,13 @@ def test_ensure_thread_dirs_acp_workspace_is_world_writable(tmp_path):
|
|||||||
assert mode == oct(0o777)
|
assert mode == oct(0o777)
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_thread_dir_rejects_invalid_thread_id(tmp_path):
    """``Paths.host_thread_dir`` raises ``ValueError`` for a traversal-style thread id."""
    thread_paths = Paths(base_dir=tmp_path)
    expect_invalid = pytest.raises(ValueError, match="Invalid thread_id")

    with expect_invalid:
        thread_paths.host_thread_dir("../escape")
|
||||||
|
|
||||||
|
|
||||||
# ── _get_thread_mounts ───────────────────────────────────────────────────────
|
# ── _get_thread_mounts ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -75,6 +82,30 @@ def test_get_thread_mounts_includes_user_data_dirs(tmp_path, monkeypatch):
|
|||||||
assert "/mnt/user-data/outputs" in container_paths
|
assert "/mnt/user-data/outputs" in container_paths
|
||||||
|
|
||||||
|
|
||||||
|
def test_join_host_path_preserves_windows_drive_letter_style():
    """Joining onto a backslash-style Windows base keeps backslash separators."""
    windows_base = r"C:\Users\demo\deer-flow\backend\.deer-flow"
    expected = r"C:\Users\demo\deer-flow\backend\.deer-flow\threads\thread-9\user-data\outputs"

    result = join_host_path(windows_base, "threads", "thread-9", "user-data", "outputs")

    assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_thread_mounts_preserves_windows_host_path_style(tmp_path, monkeypatch):
    """Host sides of thread mounts keep the Windows style of ``DEER_FLOW_HOST_BASE_DIR``."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
    monkeypatch.setenv("DEER_FLOW_HOST_BASE_DIR", r"C:\Users\demo\deer-flow\backend\.deer-flow")
    monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))

    mounts = aio_mod.AioSandboxProvider._get_thread_mounts("thread-10")

    # Index the host path of each mount by its container path.
    host_by_container = {}
    for host_path, container_path, _ in mounts:
        host_by_container[container_path] = host_path

    workspace = host_by_container["/mnt/user-data/workspace"]
    assert workspace == r"C:\Users\demo\deer-flow\backend\.deer-flow\threads\thread-10\user-data\workspace"
    uploads = host_by_container["/mnt/user-data/uploads"]
    assert uploads == r"C:\Users\demo\deer-flow\backend\.deer-flow\threads\thread-10\user-data\uploads"
    outputs = host_by_container["/mnt/user-data/outputs"]
    assert outputs == r"C:\Users\demo\deer-flow\backend\.deer-flow\threads\thread-10\user-data\outputs"
    acp_workspace = host_by_container["/mnt/acp-workspace"]
    assert acp_workspace == r"C:\Users\demo\deer-flow\backend\.deer-flow\threads\thread-10\acp-workspace"
|
||||||
|
|
||||||
|
|
||||||
def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatch):
|
def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatch):
|
||||||
"""Unlock should not run if exclusive locking itself fails."""
|
"""Unlock should not run if exclusive locking itself fails."""
|
||||||
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
|
||||||
|
|||||||
@@ -11,9 +11,16 @@ import pytest
|
|||||||
|
|
||||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||||
SCRIPT_PATH = REPO_ROOT / "scripts" / "docker.sh"
|
SCRIPT_PATH = REPO_ROOT / "scripts" / "docker.sh"
|
||||||
BASH_EXECUTABLE = which("bash") or r"C:\Program Files\Git\bin\bash.exe"
|
BASH_CANDIDATES = [
|
||||||
|
Path(r"C:\Program Files\Git\bin\bash.exe"),
|
||||||
|
Path(which("bash")) if which("bash") else None,
|
||||||
|
]
|
||||||
|
BASH_EXECUTABLE = next(
|
||||||
|
(str(path) for path in BASH_CANDIDATES if path is not None and path.exists() and "WindowsApps" not in str(path)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
if not Path(BASH_EXECUTABLE).exists():
|
if BASH_EXECUTABLE is None:
|
||||||
pytestmark = pytest.mark.skip(reason="bash is required for docker.sh detection tests")
|
pytestmark = pytest.mark.skip(reason="bash is required for docker.sh detection tests")
|
||||||
|
|
||||||
|
|
||||||
@@ -21,13 +28,14 @@ def _detect_mode_with_config(config_content: str) -> str:
|
|||||||
"""Write config content into a temp project root and execute detect_sandbox_mode."""
|
"""Write config content into a temp project root and execute detect_sandbox_mode."""
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
tmp_root = Path(tmpdir)
|
tmp_root = Path(tmpdir)
|
||||||
(tmp_root / "config.yaml").write_text(config_content)
|
(tmp_root / "config.yaml").write_text(config_content, encoding="utf-8")
|
||||||
|
|
||||||
command = f"source '{SCRIPT_PATH}' && PROJECT_ROOT='{tmp_root}' && detect_sandbox_mode"
|
command = f"source '{SCRIPT_PATH}' && PROJECT_ROOT='{tmp_root}' && detect_sandbox_mode"
|
||||||
|
|
||||||
output = subprocess.check_output(
|
output = subprocess.check_output(
|
||||||
[BASH_EXECUTABLE, "-lc", command],
|
[BASH_EXECUTABLE, "-lc", command],
|
||||||
text=True,
|
text=True,
|
||||||
|
encoding="utf-8",
|
||||||
).strip()
|
).strip()
|
||||||
|
|
||||||
return output
|
return output
|
||||||
@@ -37,7 +45,11 @@ def test_detect_mode_defaults_to_local_when_config_missing():
|
|||||||
"""No config file should default to local mode."""
|
"""No config file should default to local mode."""
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
command = f"source '{SCRIPT_PATH}' && PROJECT_ROOT='{tmpdir}' && detect_sandbox_mode"
|
command = f"source '{SCRIPT_PATH}' && PROJECT_ROOT='{tmpdir}' && detect_sandbox_mode"
|
||||||
output = subprocess.check_output([BASH_EXECUTABLE, "-lc", command], text=True).strip()
|
output = subprocess.check_output(
|
||||||
|
[BASH_EXECUTABLE, "-lc", command],
|
||||||
|
text=True,
|
||||||
|
encoding="utf-8",
|
||||||
|
).strip()
|
||||||
|
|
||||||
assert output == "local"
|
assert output == "local"
|
||||||
|
|
||||||
|
|||||||
@@ -100,3 +100,187 @@ def test_build_run_config_with_overrides():
|
|||||||
assert config["configurable"]["model_name"] == "gpt-4"
|
assert config["configurable"]["model_name"] == "gpt-4"
|
||||||
assert config["tags"] == ["test"]
|
assert config["tags"] == ["test"]
|
||||||
assert config["metadata"]["user"] == "alice"
|
assert config["metadata"]["user"] == "alice"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regression tests for issue #1644:
|
||||||
|
# assistant_id not mapped to agent_name → custom agent SOUL.md never loaded
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_run_config_custom_agent_injects_agent_name():
    """A custom ``assistant_id`` ends up in ``configurable['agent_name']``.

    Regression test for #1644: ``build_run_config`` called with
    ``assistant_id="finalis"`` must expose that value as the agent name
    so the matching SOUL.md can be selected downstream.
    """
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None, assistant_id="finalis")
    agent_name = run_config["configurable"]["agent_name"]

    assert agent_name == "finalis", "Custom assistant_id must be forwarded as configurable['agent_name'] so that make_lead_agent loads the correct SOUL.md"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_run_config_lead_agent_no_agent_name():
    """The ``'lead_agent'`` assistant id adds no ``agent_name`` entry."""
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None, assistant_id="lead_agent")
    configurable = run_config["configurable"]

    assert "agent_name" not in configurable
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_run_config_none_assistant_id_no_agent_name():
    """A ``None`` assistant id adds no ``agent_name`` entry."""
    from app.gateway.services import build_run_config

    run_config = build_run_config("thread-1", None, None, assistant_id=None)
    configurable = run_config["configurable"]

    assert "agent_name" not in configurable
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_run_config_explicit_agent_name_not_overwritten():
    """An ``agent_name`` supplied in the request config wins over ``assistant_id``."""
    from app.gateway.services import build_run_config

    request_config = {"configurable": {"agent_name": "explicit-agent"}}

    run_config = build_run_config("thread-1", request_config, None, assistant_id="other-agent")
    agent_name = run_config["configurable"]["agent_name"]

    assert agent_name == "explicit-agent", "An explicit configurable['agent_name'] in the request body must not be overwritten by the assistant_id mapping"
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_agent_factory_returns_make_lead_agent():
    """``resolve_agent_factory`` yields ``make_lead_agent`` for every sampled id."""
    from app.gateway.services import resolve_agent_factory
    from deerflow.agents.lead_agent.agent import make_lead_agent

    sampled_ids = (None, "lead_agent", "finalis", "custom-agent-123")
    for assistant_id in sampled_ids:
        assert resolve_agent_factory(assistant_id) is make_lead_agent
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regression tests for issue #1699:
|
||||||
|
# context field in langgraph-compat requests not merged into configurable
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_create_request_accepts_context():
    """``RunCreateRequest`` keeps a supplied ``context`` mapping intact."""
    from app.gateway.routers.thread_runs import RunCreateRequest

    user_message = {"role": "user", "content": "hi"}
    request_context = {
        "model_name": "deepseek-v3",
        "thinking_enabled": True,
        "is_plan_mode": True,
        "subagent_enabled": True,
        "thread_id": "some-thread-id",
    }

    body = RunCreateRequest(input={"messages": [user_message]}, context=request_context)

    assert body.context is not None
    assert body.context["model_name"] == "deepseek-v3"
    assert body.context["is_plan_mode"] is True
    assert body.context["subagent_enabled"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_create_request_context_defaults_to_none():
    """Omitting ``context`` leaves ``RunCreateRequest.context`` as ``None``."""
    from app.gateway.routers.thread_runs import RunCreateRequest

    request = RunCreateRequest(input=None)

    assert request.context is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_merges_into_configurable():
    """Allow-listed context keys are copied into ``config['configurable']``.

    ``start_run`` itself is not called; the test re-implements the merge
    locally on top of ``build_run_config`` and checks the resulting
    mapping, including that ``thread_id`` keeps the value produced by
    ``build_run_config``.
    """
    from app.gateway.services import build_run_config

    allowed_keys = {
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    }
    request_context = {
        "model_name": "deepseek-v3",
        "mode": "ultra",
        "reasoning_effort": "high",
        "thinking_enabled": True,
        "is_plan_mode": True,
        "subagent_enabled": True,
        "max_concurrent_subagents": 5,
        "thread_id": "should-be-ignored",
    }

    # Simulate the context merging logic from start_run
    config = build_run_config("thread-1", None, None)
    configurable = config.setdefault("configurable", {})
    for key in allowed_keys:
        if key not in request_context:
            continue
        configurable.setdefault(key, request_context[key])

    merged = config["configurable"]
    assert merged["model_name"] == "deepseek-v3"
    assert merged["thinking_enabled"] is True
    assert merged["is_plan_mode"] is True
    assert merged["subagent_enabled"] is True
    assert merged["max_concurrent_subagents"] == 5
    assert merged["reasoning_effort"] == "high"
    assert merged["mode"] == "ultra"
    # thread_id from context should NOT override the one from build_run_config
    assert merged["thread_id"] == "thread-1"
    # Non-allowlisted keys should not appear
    assert "thread_id" not in allowed_keys.intersection(request_context)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_does_not_override_existing_configurable():
    """Values already present in ``configurable`` survive the context merge.

    The request ``config`` sets ``model_name`` and ``is_plan_mode``; the
    locally simulated merge must leave both untouched while still adding
    ``subagent_enabled`` from the context.
    """
    from app.gateway.services import build_run_config

    allowed_keys = {
        "model_name",
        "mode",
        "thinking_enabled",
        "reasoning_effort",
        "is_plan_mode",
        "subagent_enabled",
        "max_concurrent_subagents",
    }
    request_context = {
        "model_name": "deepseek-v3",
        "is_plan_mode": True,
        "subagent_enabled": True,
    }
    request_config = {"configurable": {"model_name": "gpt-4", "is_plan_mode": False}}

    config = build_run_config("thread-1", request_config, None)
    configurable = config.setdefault("configurable", {})
    for key in allowed_keys:
        if key not in request_context:
            continue
        configurable.setdefault(key, request_context[key])

    merged = config["configurable"]
    # Existing values must NOT be overridden
    assert merged["model_name"] == "gpt-4"
    assert merged["is_plan_mode"] is False
    # New values should be added
    assert merged["subagent_enabled"] is True
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from deerflow.agents.lead_agent import prompt as prompt_module
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_custom_mounts_section_returns_empty_when_no_mounts(monkeypatch):
    """With an empty ``sandbox.mounts`` list the section is the empty string."""
    empty_config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[]))
    monkeypatch.setattr("deerflow.config.get_app_config", lambda: empty_config)

    section = prompt_module._build_custom_mounts_section()

    assert section == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
    """Each configured container path and both access modes appear in the section."""
    shared_mount = SimpleNamespace(container_path="/home/user/shared", read_only=False)
    reference_mount = SimpleNamespace(container_path="/mnt/reference", read_only=True)
    config = SimpleNamespace(sandbox=SimpleNamespace(mounts=[shared_mount, reference_mount]))
    monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)

    section = prompt_module._build_custom_mounts_section()

    expected_fragments = (
        "**Custom Mounted Directories:**",
        "`/home/user/shared`",
        "read-write",
        "`/mnt/reference`",
        "read-only",
    )
    for fragment in expected_fragments:
        assert fragment in section
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
    """The rendered prompt mentions the configured custom mount.

    Skills, deferred tools, the ACP section, memory context and the
    agent soul are all stubbed to empty values so only the mount section
    is under test.
    """
    shared_mount = SimpleNamespace(container_path="/home/user/shared", read_only=False)
    config = SimpleNamespace(
        sandbox=SimpleNamespace(mounts=[shared_mount]),
        skills=SimpleNamespace(container_path="/mnt/skills"),
    )
    monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)

    # Neutralise every other prompt contributor.
    monkeypatch.setattr(prompt_module, "load_skills", lambda enabled_only=True: [])
    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda: "")
    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

    rendered = prompt_module.apply_prompt_template()

    assert "`/home/user/shared`" in rendered
    assert "Custom Mounted Directories" in rendered
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
redis = pytest.importorskip("redis.asyncio")
|
|
||||||
|
|
||||||
from deerflow.actor.mailbox_redis import RedisMailbox
|
|
||||||
from deerflow.actor.ref import _Envelope, _Stop
|
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.anyio
|
|
||||||
|
|
||||||
|
|
||||||
async def _make_mailbox(queue_name: str, *, maxlen: int = 0) -> RedisMailbox:
    """Build a ``RedisMailbox`` on a local Redis after clearing ``queue_name``.

    Connects to 127.0.0.1:6379, pings the server, deletes the queue key
    and returns a mailbox that shares the client's connection pool.
    """
    redis_client = redis.Redis(host="127.0.0.1", port=6379, decode_responses=False)
    await redis_client.ping()
    # Start every test from an empty queue.
    await redis_client.delete(queue_name)
    return RedisMailbox(redis_client.connection_pool, queue_name, maxlen=maxlen, brpop_timeout=0.2)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_roundtrip_envelope_and_stop():
    """An ``_Envelope`` and a ``_Stop`` both survive a put/get round trip."""
    mailbox = await _make_mailbox("deerflow:test:redis-mailbox:roundtrip")
    try:
        outgoing = _Envelope(payload={"k": "v"}, correlation_id="c1", reply_to="sysA")
        accepted = await mailbox.put(outgoing)
        assert accepted is True

        incoming = await mailbox.get()
        assert isinstance(incoming, _Envelope)
        assert incoming.payload == {"k": "v"}
        assert incoming.correlation_id == "c1"
        assert incoming.reply_to == "sysA"

        accepted = await mailbox.put(_Stop())
        assert accepted is True
        sentinel = await mailbox.get()
        assert isinstance(sentinel, _Stop)
    finally:
        await mailbox.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_bounded_queue_rejects_when_full():
    """With ``maxlen=1`` the first ``put`` returns True and the second False."""
    mailbox = await _make_mailbox("deerflow:test:redis-mailbox:bounded", maxlen=1)
    try:
        first_accepted = await mailbox.put(_Envelope("m1"))
        assert first_accepted is True
        second_accepted = await mailbox.put(_Envelope("m2"))
        assert second_accepted is False
    finally:
        await mailbox.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_put_nowait_and_get_nowait_contract():
    """``put_nowait`` returns False and ``get_nowait`` raises on a Redis mailbox."""
    mailbox = await _make_mailbox("deerflow:test:redis-mailbox:nowait")
    try:
        accepted = mailbox.put_nowait(_Envelope("x"))
        assert accepted is False

        expect_unsupported = pytest.raises(Exception, match="does not support synchronous get_nowait")
        with expect_unsupported:
            mailbox.get_nowait()
    finally:
        await mailbox.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_system_enqueue_fallback_with_async_mailbox():
    """An actor backed by a Redis mailbox still answers ``ask`` through the system."""
    from deerflow.actor import Actor, ActorSystem

    class EchoActor(Actor):
        async def on_receive(self, message):
            # Reply with exactly what was received.
            return message

    mailbox = await _make_mailbox("deerflow:test:redis-mailbox:system-fallback")

    system = ActorSystem("redis-test")
    ref = await system.spawn(EchoActor, "echo", mailbox=mailbox)
    try:
        # This exercises _ActorCell.enqueue fallback path:
        # put_nowait() -> False, then await put() -> True
        echoed = await ref.ask("hello", timeout=3.0)
        assert echoed == "hello"
    finally:
        await system.shutdown()
|
|
||||||
@@ -86,6 +86,18 @@ async def test_list_by_thread(manager: RunManager):
|
|||||||
assert runs[1].run_id == r1.run_id
|
assert runs[1].run_id == r1.run_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch):
    """Runs created under an identical timestamp are still listed newest first.

    ``_now_iso`` is frozen so both runs share one timestamp; the listing
    must return the second run before the first.
    """
    frozen_timestamp = "2026-01-01T00:00:00+00:00"
    monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: frozen_timestamp)

    older = await manager.create("thread-1")
    newer = await manager.create("thread-1")

    listed = await manager.list_by_thread("thread-1")
    listed_ids = [run.run_id for run in listed]
    assert listed_ids == [newer.run_id, older.run_id]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_has_inflight(manager: RunManager):
|
async def test_has_inflight(manager: RunManager):
|
||||||
"""has_inflight should be True when a run is pending or running."""
|
"""has_inflight should be True when a run is pending or running."""
|
||||||
|
|||||||
Generated
+2262
-2273
File diff suppressed because it is too large
Load Diff
@@ -397,6 +397,9 @@ sandbox:
|
|||||||
# # - host_path: /path/on/host
|
# # - host_path: /path/on/host
|
||||||
# # container_path: /home/user/shared
|
# # container_path: /home/user/shared
|
||||||
# # read_only: false
|
# # read_only: false
|
||||||
|
# #
|
||||||
|
# # # DeerFlow will surface configured container_path values to the agent,
|
||||||
|
# # # so it can directly read/write mounted directories such as /home/user/shared
|
||||||
#
|
#
|
||||||
# # Optional: Environment variables to inject into the sandbox container
|
# # Optional: Environment variables to inject into the sandbox container
|
||||||
# # Values starting with $ will be resolved from host environment variables
|
# # Values starting with $ will be resolved from host environment variables
|
||||||
|
|||||||
@@ -121,8 +121,8 @@ services:
|
|||||||
container_name: deer-flow-langgraph
|
container_name: deer-flow-langgraph
|
||||||
command: sh -c "cd /app/backend && uv run langgraph dev --no-browser --allow-blocking --no-reload --host 0.0.0.0 --port 2024 --n-jobs-per-worker 10"
|
command: sh -c "cd /app/backend && uv run langgraph dev --no-browser --allow-blocking --no-reload --host 0.0.0.0 --port 2024 --n-jobs-per-worker 10"
|
||||||
volumes:
|
volumes:
|
||||||
- ${DEER_FLOW_CONFIG_PATH}:/app/config.yaml:ro
|
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
||||||
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/extensions_config.json:ro
|
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
||||||
- ${DEER_FLOW_HOME}:/app/backend/.deer-flow
|
- ${DEER_FLOW_HOME}:/app/backend/.deer-flow
|
||||||
- ../skills:/app/skills:ro
|
- ../skills:/app/skills:ro
|
||||||
- ../backend/.langgraph_api:/app/backend/.langgraph_api
|
- ../backend/.langgraph_api:/app/backend/.langgraph_api
|
||||||
@@ -144,14 +144,12 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- CI=true
|
- CI=true
|
||||||
- DEER_FLOW_HOME=/app/backend/.deer-flow
|
- DEER_FLOW_HOME=/app/backend/.deer-flow
|
||||||
- DEER_FLOW_CONFIG_PATH=/app/config.yaml
|
- DEER_FLOW_CONFIG_PATH=/app/backend/config.yaml
|
||||||
- DEER_FLOW_EXTENSIONS_CONFIG_PATH=/app/extensions_config.json
|
- DEER_FLOW_EXTENSIONS_CONFIG_PATH=/app/backend/extensions_config.json
|
||||||
- DEER_FLOW_HOST_BASE_DIR=${DEER_FLOW_HOME}
|
- DEER_FLOW_HOST_BASE_DIR=${DEER_FLOW_HOME}
|
||||||
- DEER_FLOW_HOST_SKILLS_PATH=${DEER_FLOW_REPO_ROOT}/skills
|
- DEER_FLOW_HOST_SKILLS_PATH=${DEER_FLOW_REPO_ROOT}/skills
|
||||||
- DEER_FLOW_SANDBOX_HOST=host.docker.internal
|
- DEER_FLOW_SANDBOX_HOST=host.docker.internal
|
||||||
# Disable LangSmith tracing — LANGSMITH_API_KEY is not required.
|
# LangSmith tracing: set LANGSMITH_TRACING=true and LANGSMITH_API_KEY in .env to enable.
|
||||||
# Set LANGSMITH_TRACING=true and LANGSMITH_API_KEY in .env to enable.
|
|
||||||
- LANGSMITH_TRACING=${LANGSMITH_TRACING:-false}
|
|
||||||
env_file:
|
env_file:
|
||||||
- ../.env
|
- ../.env
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
@@ -39,7 +40,7 @@ from fastapi import FastAPI, HTTPException
|
|||||||
from kubernetes import client as k8s_client
|
from kubernetes import client as k8s_client
|
||||||
from kubernetes import config as k8s_config
|
from kubernetes import config as k8s_config
|
||||||
from kubernetes.client.rest import ApiException
|
from kubernetes.client.rest import ApiException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
# Suppress only the InsecureRequestWarning from urllib3
|
# Suppress only the InsecureRequestWarning from urllib3
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
@@ -59,6 +60,7 @@ SANDBOX_IMAGE = os.environ.get(
|
|||||||
)
|
)
|
||||||
SKILLS_HOST_PATH = os.environ.get("SKILLS_HOST_PATH", "/skills")
|
SKILLS_HOST_PATH = os.environ.get("SKILLS_HOST_PATH", "/skills")
|
||||||
THREADS_HOST_PATH = os.environ.get("THREADS_HOST_PATH", "/.deer-flow/threads")
|
THREADS_HOST_PATH = os.environ.get("THREADS_HOST_PATH", "/.deer-flow/threads")
|
||||||
|
SAFE_THREAD_ID_PATTERN = r"^[A-Za-z0-9_\-]+$"
|
||||||
|
|
||||||
# Path to the kubeconfig *inside* the provisioner container.
|
# Path to the kubeconfig *inside* the provisioner container.
|
||||||
# Typically the host's ~/.kube/config is mounted here.
|
# Typically the host's ~/.kube/config is mounted here.
|
||||||
@@ -69,6 +71,36 @@ KUBECONFIG_PATH = os.environ.get("KUBECONFIG_PATH", "/root/.kube/config")
|
|||||||
# is ``host.docker.internal``; on Linux it may be the host's LAN IP.
|
# is ``host.docker.internal``; on Linux it may be the host's LAN IP.
|
||||||
NODE_HOST = os.environ.get("NODE_HOST", "host.docker.internal")
|
NODE_HOST = os.environ.get("NODE_HOST", "host.docker.internal")
|
||||||
|
|
||||||
|
|
||||||
|
def join_host_path(base: str, *parts: str) -> str:
    """Join host filesystem path segments while preserving the base's path style.

    The style is decided from ``base`` alone, never from the operating system
    this process runs on, because ``base`` describes a path on another machine
    (the host / cluster node).

    Args:
        base: Host path to extend. It is treated as a Windows path when it
            starts with a drive letter followed by a slash or backslash, or
            contains a backslash anywhere (which covers UNC prefixes). Any
            other value is treated as a POSIX path.
        *parts: Segments to append, in order.

    Returns:
        ``base`` unchanged when no parts are given; otherwise the joined path,
        rendered with backslashes for Windows-style bases and forward slashes
        for POSIX-style bases.
    """
    if not parts:
        return base

    # Pure paths do string manipulation only and do not depend on the local OS.
    # A concrete ``Path`` would become a ``WindowsPath`` when this code runs on
    # Windows and rewrite a POSIX-style base with backslashes.
    from pathlib import PurePosixPath, PureWindowsPath

    looks_like_windows = (
        re.match(r"^[A-Za-z]:[\\/]", base) is not None
        or base.startswith("\\\\")
        or "\\" in base
    )
    path_type = PureWindowsPath if looks_like_windows else PurePosixPath
    return str(path_type(base, *parts))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_thread_id(thread_id: str) -> str:
    """Return ``thread_id`` unchanged if it matches ``SAFE_THREAD_ID_PATTERN``.

    Args:
        thread_id: Identifier to check before it is used to build a path.

    Returns:
        The same ``thread_id`` that was passed in.

    Raises:
        ValueError: If ``thread_id`` does not match the pattern in full.
    """
    # fullmatch, not match: with match(), a pattern ending in "$" also accepts
    # a trailing newline ("abc\n"), which would let a newline into the path.
    if re.fullmatch(SAFE_THREAD_ID_PATTERN, thread_id) is None:
        raise ValueError(
            "Invalid thread_id: only alphanumeric characters, hyphens, and underscores are allowed."
        )
    return thread_id
|
||||||
|
|
||||||
|
|
||||||
# ── K8s client setup ────────────────────────────────────────────────────
|
# ── K8s client setup ────────────────────────────────────────────────────
|
||||||
|
|
||||||
core_v1: k8s_client.CoreV1Api | None = None
|
core_v1: k8s_client.CoreV1Api | None = None
|
||||||
@@ -186,7 +218,7 @@ app = FastAPI(title="DeerFlow Sandbox Provisioner", lifespan=lifespan)
|
|||||||
|
|
||||||
class CreateSandboxRequest(BaseModel):
|
class CreateSandboxRequest(BaseModel):
|
||||||
sandbox_id: str
|
sandbox_id: str
|
||||||
thread_id: str
|
thread_id: str = Field(pattern=SAFE_THREAD_ID_PATTERN)
|
||||||
|
|
||||||
|
|
||||||
class SandboxResponse(BaseModel):
|
class SandboxResponse(BaseModel):
|
||||||
@@ -213,6 +245,7 @@ def _sandbox_url(node_port: int) -> str:
|
|||||||
|
|
||||||
def _build_pod(sandbox_id: str, thread_id: str) -> k8s_client.V1Pod:
|
def _build_pod(sandbox_id: str, thread_id: str) -> k8s_client.V1Pod:
|
||||||
"""Construct a Pod manifest for a single sandbox."""
|
"""Construct a Pod manifest for a single sandbox."""
|
||||||
|
thread_id = _validate_thread_id(thread_id)
|
||||||
return k8s_client.V1Pod(
|
return k8s_client.V1Pod(
|
||||||
metadata=k8s_client.V1ObjectMeta(
|
metadata=k8s_client.V1ObjectMeta(
|
||||||
name=_pod_name(sandbox_id),
|
name=_pod_name(sandbox_id),
|
||||||
@@ -298,7 +331,7 @@ def _build_pod(sandbox_id: str, thread_id: str) -> k8s_client.V1Pod:
|
|||||||
k8s_client.V1Volume(
|
k8s_client.V1Volume(
|
||||||
name="user-data",
|
name="user-data",
|
||||||
host_path=k8s_client.V1HostPathVolumeSource(
|
host_path=k8s_client.V1HostPathVolumeSource(
|
||||||
path=f"{THREADS_HOST_PATH}/{thread_id}/user-data",
|
path=join_host_path(THREADS_HOST_PATH, thread_id, "user-data"),
|
||||||
type="DirectoryOrCreate",
|
type="DirectoryOrCreate",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -80,13 +80,9 @@ export default function NewAgentPage() {
|
|||||||
setNameError(t.agents.nameStepAlreadyExistsError);
|
setNameError(t.agents.nameStepAlreadyExistsError);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (err) {
|
||||||
if (error instanceof AgentNameCheckError) {
|
if (err instanceof TypeError && err.message === "Failed to fetch") {
|
||||||
setNameError(
|
setNameError(t.agents.nameStepNetworkError);
|
||||||
error.reason === "backend_unreachable"
|
|
||||||
? t.agents.nameStepCheckError
|
|
||||||
: error.message,
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
setNameError(t.agents.nameStepCheckError);
|
setNameError(t.agents.nameStepCheckError);
|
||||||
}
|
}
|
||||||
@@ -107,6 +103,7 @@ export default function NewAgentPage() {
|
|||||||
t.agents.nameStepBootstrapMessage,
|
t.agents.nameStepBootstrapMessage,
|
||||||
t.agents.nameStepInvalidError,
|
t.agents.nameStepInvalidError,
|
||||||
t.agents.nameStepAlreadyExistsError,
|
t.agents.nameStepAlreadyExistsError,
|
||||||
|
t.agents.nameStepNetworkError,
|
||||||
t.agents.nameStepCheckError,
|
t.agents.nameStepCheckError,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ export function ArtifactFileDetail({
|
|||||||
const isSupportPreview = useMemo(() => {
|
const isSupportPreview = useMemo(() => {
|
||||||
return language === "html" || language === "markdown";
|
return language === "html" || language === "markdown";
|
||||||
}, [language]);
|
}, [language]);
|
||||||
const { content } = useArtifactContent({
|
const { content, url } = useArtifactContent({
|
||||||
threadId,
|
threadId,
|
||||||
filepath: filepathFromProps,
|
filepath: filepathFromProps,
|
||||||
enabled: isCodeFile && !isWriteFile,
|
enabled: isCodeFile && !isWriteFile,
|
||||||
@@ -240,7 +240,9 @@ export function ArtifactFileDetail({
|
|||||||
(language === "markdown" || language === "html") && (
|
(language === "markdown" || language === "html") && (
|
||||||
<ArtifactFilePreview
|
<ArtifactFilePreview
|
||||||
content={displayContent}
|
content={displayContent}
|
||||||
|
isWriteFile={isWriteFile}
|
||||||
language={language ?? "text"}
|
language={language ?? "text"}
|
||||||
|
url={url}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{isCodeFile && viewMode === "code" && (
|
{isCodeFile && viewMode === "code" && (
|
||||||
@@ -263,10 +265,14 @@ export function ArtifactFileDetail({
|
|||||||
|
|
||||||
export function ArtifactFilePreview({
|
export function ArtifactFilePreview({
|
||||||
content,
|
content,
|
||||||
|
isWriteFile,
|
||||||
language,
|
language,
|
||||||
|
url,
|
||||||
}: {
|
}: {
|
||||||
content: string;
|
content: string;
|
||||||
|
isWriteFile: boolean;
|
||||||
language: string;
|
language: string;
|
||||||
|
url?: string;
|
||||||
}) {
|
}) {
|
||||||
if (language === "markdown") {
|
if (language === "markdown") {
|
||||||
return (
|
return (
|
||||||
@@ -286,8 +292,8 @@ export function ArtifactFilePreview({
|
|||||||
<iframe
|
<iframe
|
||||||
className="size-full"
|
className="size-full"
|
||||||
title="Artifact preview"
|
title="Artifact preview"
|
||||||
srcDoc={content}
|
|
||||||
sandbox="allow-scripts allow-forms"
|
sandbox="allow-scripts allow-forms"
|
||||||
|
{...(isWriteFile ? { srcDoc: content } : url ? { src: url } : {})}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,5 +34,10 @@ export function useArtifactContent({
|
|||||||
// Cache artifact content for 5 minutes to avoid repeated fetches (especially for .skill ZIP extraction)
|
// Cache artifact content for 5 minutes to avoid repeated fetches (especially for .skill ZIP extraction)
|
||||||
staleTime: 5 * 60 * 1000,
|
staleTime: 5 * 60 * 1000,
|
||||||
});
|
});
|
||||||
return { content: isWriteFile ? content : data, isLoading, error };
|
return {
|
||||||
|
content: isWriteFile ? content : data?.content,
|
||||||
|
url: isWriteFile ? undefined : data?.url,
|
||||||
|
isLoading,
|
||||||
|
error,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ export async function loadArtifactContent({
|
|||||||
const url = urlOfArtifact({ filepath: enhancedFilepath, threadId, isMock });
|
const url = urlOfArtifact({ filepath: enhancedFilepath, threadId, isMock });
|
||||||
const response = await fetch(url);
|
const response = await fetch(url);
|
||||||
const text = await response.text();
|
const text = await response.text();
|
||||||
return text;
|
return { content: text, url };
|
||||||
}
|
}
|
||||||
|
|
||||||
export function loadArtifactContentFromToolCall({
|
export function loadArtifactContentFromToolCall({
|
||||||
|
|||||||
@@ -194,8 +194,9 @@ export const enUS: Translations = {
|
|||||||
nameStepInvalidError:
|
nameStepInvalidError:
|
||||||
"Invalid name — use only letters, digits, and hyphens",
|
"Invalid name — use only letters, digits, and hyphens",
|
||||||
nameStepAlreadyExistsError: "An agent with this name already exists",
|
nameStepAlreadyExistsError: "An agent with this name already exists",
|
||||||
nameStepCheckError:
|
nameStepNetworkError:
|
||||||
"Could not reach the DeerFlow backend to verify name availability. Start the backend or set NEXT_PUBLIC_BACKEND_BASE_URL, then try again.",
|
"Network request failed — check your network or backend connection",
|
||||||
|
nameStepCheckError: "Could not verify name availability — please try again",
|
||||||
nameStepBootstrapMessage:
|
nameStepBootstrapMessage:
|
||||||
"The new custom agent name is {name}. Let's bootstrap it's **SOUL**.",
|
"The new custom agent name is {name}. Let's bootstrap it's **SOUL**.",
|
||||||
agentCreated: "Agent created!",
|
agentCreated: "Agent created!",
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ export interface Translations {
|
|||||||
nameStepContinue: string;
|
nameStepContinue: string;
|
||||||
nameStepInvalidError: string;
|
nameStepInvalidError: string;
|
||||||
nameStepAlreadyExistsError: string;
|
nameStepAlreadyExistsError: string;
|
||||||
|
nameStepNetworkError: string;
|
||||||
nameStepCheckError: string;
|
nameStepCheckError: string;
|
||||||
nameStepBootstrapMessage: string;
|
nameStepBootstrapMessage: string;
|
||||||
agentCreated: string;
|
agentCreated: string;
|
||||||
|
|||||||
@@ -183,8 +183,8 @@ export const zhCN: Translations = {
|
|||||||
nameStepContinue: "继续",
|
nameStepContinue: "继续",
|
||||||
nameStepInvalidError: "名称无效,只允许字母、数字和连字符",
|
nameStepInvalidError: "名称无效,只允许字母、数字和连字符",
|
||||||
nameStepAlreadyExistsError: "已存在同名智能体",
|
nameStepAlreadyExistsError: "已存在同名智能体",
|
||||||
nameStepCheckError:
|
nameStepNetworkError: "网络请求失败,请检查网络或后端连接",
|
||||||
"无法连接 DeerFlow 后端来验证名称是否可用。请先启动后端,或配置 NEXT_PUBLIC_BACKEND_BASE_URL,然后再重试。",
|
nameStepCheckError: "无法验证名称可用性,请稍后重试",
|
||||||
nameStepBootstrapMessage:
|
nameStepBootstrapMessage:
|
||||||
"新智能体的名称是 {name},现在开始为它生成 **SOUL**。",
|
"新智能体的名称是 {name},现在开始为它生成 **SOUL**。",
|
||||||
agentCreated: "智能体已创建!",
|
agentCreated: "智能体已创建!",
|
||||||
|
|||||||
+15
-2
@@ -9,6 +9,17 @@ import sys
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def configure_stdio() -> None:
    """Switch stdout and stderr to UTF-8 so Unicode status markers render on Windows.

    Streams that are missing or have no ``reconfigure`` method are skipped.
    Encoding errors are set to ``"replace"``.
    """
    for name in ("stdout", "stderr"):
        target = getattr(sys, name, None)
        if not hasattr(target, "reconfigure"):
            continue
        try:
            target.reconfigure(encoding="utf-8", errors="replace")
        except (OSError, ValueError):
            # Best effort: leave a stream alone if it refuses reconfiguration.
            pass
|
||||||
|
|
||||||
|
|
||||||
def run_command(command: list[str]) -> Optional[str]:
|
def run_command(command: list[str]) -> Optional[str]:
|
||||||
"""Run a command and return trimmed stdout, or None on failure."""
|
"""Run a command and return trimmed stdout, or None on failure."""
|
||||||
try:
|
try:
|
||||||
@@ -29,6 +40,7 @@ def parse_node_major(version_text: str) -> Optional[int]:
|
|||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
def main() -> int:
|
||||||
|
configure_stdio()
|
||||||
print("==========================================")
|
print("==========================================")
|
||||||
print(" Checking Required Dependencies")
|
print(" Checking Required Dependencies")
|
||||||
print("==========================================")
|
print("==========================================")
|
||||||
@@ -61,8 +73,9 @@ def main() -> int:
|
|||||||
|
|
||||||
print()
|
print()
|
||||||
print("Checking pnpm...")
|
print("Checking pnpm...")
|
||||||
if shutil.which("pnpm"):
|
pnpm_executable = shutil.which("pnpm.cmd") or shutil.which("pnpm")
|
||||||
pnpm_version = run_command(["pnpm", "-v"])
|
if pnpm_executable:
|
||||||
|
pnpm_version = run_command([pnpm_executable, "-v"])
|
||||||
if pnpm_version:
|
if pnpm_version:
|
||||||
print(f" ✓ pnpm {pnpm_version}")
|
print(f" ✓ pnpm {pnpm_version}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user