fix(sandbox): avoid blocking sandbox readiness polling (#2822)

* fix(sandbox): offload async sandbox acquisition

Run blocking sandbox provider acquisition through the async provider hook so eager sandbox setup does not stall the event loop.

* fix(sandbox): add async readiness polling

Introduce an async sandbox readiness poller using httpx and asyncio.sleep while preserving the existing synchronous API.

* test(sandbox): cover async readiness polling

Lock in non-blocking readiness behavior so the async helper does not regress to requests.get or time.sleep.

* fix(sandbox): allow anonymous backend creation

* fix(sandbox): use async readiness in provider acquisition

* fix(sandbox): use async acquisition for lazy tools

* test(sandbox): cover anonymous remote creation

* fix(sandbox): clamp async readiness timeout budget

* fix(sandbox): offload async lock file handling

* fix(sandbox): delegate async middleware fallthrough

* docs(sandbox): document async acquisition path

* fix(sandbox): offload async sandbox release

* docs(sandbox): mention async release hook

* fix(sandbox): address async lock review

Reduce duplicate sync/async sandbox acquisition state handling and move async thread-lock waits onto a dedicated executor with cancellation-safe cleanup.

* chore: retrigger ci

Retrigger GitHub Actions after upstream main fixed the stale PR merge lint failure.

* test(sandbox): sync backend unit fixtures

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
AochenShen99
2026-05-21 14:44:34 +08:00
committed by GitHub
parent dcc6f1e678
commit 8b697245eb
13 changed files with 1037 additions and 73 deletions
+1 -1
View File
@@ -236,7 +236,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
### Sandbox System (`packages/harness/deerflow/sandbox/`) ### Sandbox System (`packages/harness/deerflow/sandbox/`)
**Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir` **Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir`
**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle **Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
**Implementations**: **Implementations**:
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`. - `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation - `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
+1 -1
View File
@@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern:
Per-thread isolated execution with virtual path translation: Per-thread isolated execution with virtual path translation:
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir` - **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/) - **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories - **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
- **Skills path**: `/mnt/skills``deer-flow/skills/` directory - **Skills path**: `/mnt/skills``deer-flow/skills/` directory
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths - **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
@@ -10,6 +10,7 @@ The provider itself handles:
- Mount computation (thread-specific, skills) - Mount computation (thread-specific, skills)
""" """
import asyncio
import atexit import atexit
import hashlib import hashlib
import logging import logging
@@ -18,6 +19,7 @@ import signal
import threading import threading
import time import time
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor
try: try:
import fcntl import fcntl
@@ -32,7 +34,7 @@ from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider from deerflow.sandbox.sandbox_provider import SandboxProvider
from .aio_sandbox import AioSandbox from .aio_sandbox import AioSandbox
from .backend import SandboxBackend, wait_for_sandbox_ready from .backend import SandboxBackend, wait_for_sandbox_ready, wait_for_sandbox_ready_async
from .local_backend import LocalContainerBackend from .local_backend import LocalContainerBackend
from .remote_backend import RemoteSandboxBackend from .remote_backend import RemoteSandboxBackend
from .sandbox_info import SandboxInfo from .sandbox_info import SandboxInfo
@@ -46,6 +48,9 @@ DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox"
DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds
DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers
IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds
THREAD_LOCK_EXECUTOR_WORKERS = min(32, (os.cpu_count() or 1) + 4)
_THREAD_LOCK_EXECUTOR = ThreadPoolExecutor(max_workers=THREAD_LOCK_EXECUTOR_WORKERS, thread_name_prefix="sandbox-lock-wait")
atexit.register(_THREAD_LOCK_EXECUTOR.shutdown, wait=False, cancel_futures=True)
def _lock_file_exclusive(lock_file) -> None: def _lock_file_exclusive(lock_file) -> None:
@@ -66,6 +71,40 @@ def _unlock_file(lock_file) -> None:
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
def _open_lock_file(lock_path):
return open(lock_path, "a", encoding="utf-8")
async def _acquire_thread_lock_async(lock: threading.Lock) -> None:
"""Acquire a threading.Lock without polling or using the default executor."""
loop = asyncio.get_running_loop()
acquire_future = loop.run_in_executor(_THREAD_LOCK_EXECUTOR, lock.acquire, True)
try:
acquired = await asyncio.shield(acquire_future)
except asyncio.CancelledError:
acquire_future.add_done_callback(lambda task: _release_cancelled_lock_acquire(lock, task))
raise
if not acquired:
raise RuntimeError("Failed to acquire sandbox thread lock")
def _release_cancelled_lock_acquire(lock: threading.Lock, task: asyncio.Future[bool]) -> None:
"""Release a lock acquired after its awaiting coroutine was cancelled."""
if task.cancelled():
return
try:
acquired = task.result()
except Exception as e:
logger.warning(f"Cancelled sandbox lock acquisition finished with error: {e}")
return
if acquired:
lock.release()
class AioSandboxProvider(SandboxProvider): class AioSandboxProvider(SandboxProvider):
"""Sandbox provider that manages containers running the AIO sandbox. """Sandbox provider that manages containers running the AIO sandbox.
@@ -416,6 +455,96 @@ class AioSandboxProvider(SandboxProvider):
self._thread_locks[thread_id] = threading.Lock() self._thread_locks[thread_id] = threading.Lock()
return self._thread_locks[thread_id] return self._thread_locks[thread_id]
def _sandbox_id_for_thread(self, thread_id: str | None) -> str:
"""Return deterministic IDs for thread sandboxes and random IDs otherwise."""
return self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]
def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool = False) -> str | None:
"""Reuse an active in-process sandbox for a thread if one is still tracked."""
if thread_id is None:
return None
with self._lock:
if thread_id not in self._thread_sandboxes:
return None
existing_id = self._thread_sandboxes[thread_id]
if existing_id in self._sandboxes:
suffix = " (post-lock check)" if post_lock else ""
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
self._last_activity[existing_id] = time.time()
return existing_id
del self._thread_sandboxes[thread_id]
return None
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
"""Promote a warm-pool sandbox back to active tracking if available."""
if thread_id is None:
return None
with self._lock:
if sandbox_id not in self._warm_pool:
return None
info, _ = self._warm_pool.pop(sandbox_id)
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id
suffix = " (post-lock check)" if post_lock else f" at {info.sandbox_url}"
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id}{suffix}")
return sandbox_id
def _recheck_cached_sandbox(self, thread_id: str, sandbox_id: str) -> str | None:
"""Re-check in-memory caches after acquiring the cross-process file lock."""
return self._reuse_in_process_sandbox(thread_id, post_lock=True) or self._reclaim_warm_pool_sandbox(thread_id, sandbox_id, post_lock=True)
def _register_discovered_sandbox(self, thread_id: str, info: SandboxInfo) -> str:
"""Track a sandbox discovered through the backend."""
sandbox = AioSandbox(id=info.sandbox_id, base_url=info.sandbox_url)
with self._lock:
self._sandboxes[info.sandbox_id] = sandbox
self._sandbox_infos[info.sandbox_id] = info
self._last_activity[info.sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = info.sandbox_id
logger.info(f"Discovered existing sandbox {info.sandbox_id} for thread {thread_id} at {info.sandbox_url}")
return info.sandbox_id
def _register_created_sandbox(self, thread_id: str | None, sandbox_id: str, info: SandboxInfo) -> str:
"""Track a newly-created sandbox in the active maps."""
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
with self._lock:
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
return sandbox_id
def _replica_count(self) -> tuple[int, int]:
"""Return configured replicas and currently tracked sandbox count."""
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
with self._lock:
total = len(self._sandboxes) + len(self._warm_pool)
return replicas, total
def _log_replicas_soft_cap(self, replicas: int, sandbox_id: str, evicted: str | None) -> None:
"""Log the result of enforcing the warm-pool replica budget."""
if evicted:
logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
return
# All slots are occupied by active sandboxes — proceed anyway and log.
# The replicas limit is a soft cap; we never forcibly stop a container
# that is actively serving a thread.
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
# ── Core: acquire / get / release / shutdown ───────────────────────── # ── Core: acquire / get / release / shutdown ─────────────────────────
def acquire(self, thread_id: str | None = None) -> str: def acquire(self, thread_id: str | None = None) -> str:
@@ -440,6 +569,23 @@ class AioSandboxProvider(SandboxProvider):
else: else:
return self._acquire_internal(thread_id) return self._acquire_internal(thread_id)
async def acquire_async(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox environment without blocking the event loop.
Mirrors ``acquire()`` while keeping blocking backend operations off the
event loop and using async-native readiness polling for newly created
sandboxes.
"""
if thread_id:
thread_lock = self._get_thread_lock(thread_id)
await _acquire_thread_lock_async(thread_lock)
try:
return await self._acquire_internal_async(thread_id)
finally:
thread_lock.release()
return await self._acquire_internal_async(thread_id)
def _acquire_internal(self, thread_id: str | None) -> str: def _acquire_internal(self, thread_id: str | None) -> str:
"""Internal sandbox acquisition with two-layer consistency. """Internal sandbox acquisition with two-layer consistency.
@@ -448,33 +594,17 @@ class AioSandboxProvider(SandboxProvider):
sandbox_id is deterministic from thread_id so no shared state file sandbox_id is deterministic from thread_id so no shared state file
is needed — any process can derive the same container name) is needed — any process can derive the same container name)
""" """
# ── Layer 1: In-process cache (fast path) ── cached_id = self._reuse_in_process_sandbox(thread_id)
if thread_id: if cached_id is not None:
with self._lock: return cached_id
if thread_id in self._thread_sandboxes:
existing_id = self._thread_sandboxes[thread_id]
if existing_id in self._sandboxes:
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}")
self._last_activity[existing_id] = time.time()
return existing_id
else:
del self._thread_sandboxes[thread_id]
# Deterministic ID for thread-specific, random for anonymous # Deterministic ID for thread-specific, random for anonymous
sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8] sandbox_id = self._sandbox_id_for_thread(thread_id)
# ── Layer 1.5: Warm pool (container still running, no cold-start) ── # ── Layer 1.5: Warm pool (container still running, no cold-start) ──
if thread_id: reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
with self._lock: if reclaimed_id is not None:
if sandbox_id in self._warm_pool: return reclaimed_id
info, _ = self._warm_pool.pop(sandbox_id)
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
return sandbox_id
# ── Layer 2: Backend discovery + create (protected by cross-process lock) ── # ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
# Use a file lock so that two processes racing to create the same sandbox # Use a file lock so that two processes racing to create the same sandbox
@@ -485,6 +615,26 @@ class AioSandboxProvider(SandboxProvider):
return self._create_sandbox(thread_id, sandbox_id) return self._create_sandbox(thread_id, sandbox_id)
async def _acquire_internal_async(self, thread_id: str | None) -> str:
"""Async counterpart to ``_acquire_internal``."""
cached_id = self._reuse_in_process_sandbox(thread_id)
if cached_id is not None:
return cached_id
# Deterministic ID for thread-specific, random for anonymous
sandbox_id = self._sandbox_id_for_thread(thread_id)
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
if reclaimed_id is not None:
return reclaimed_id
# ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
if thread_id:
return await self._discover_or_create_with_lock_async(thread_id, sandbox_id)
return await self._create_sandbox_async(thread_id, sandbox_id)
def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str: def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str:
"""Discover an existing sandbox or create a new one under a cross-process file lock. """Discover an existing sandbox or create a new one under a cross-process file lock.
@@ -503,40 +653,50 @@ class AioSandboxProvider(SandboxProvider):
locked = True locked = True
# Re-check in-process caches under the file lock in case another # Re-check in-process caches under the file lock in case another
# thread in this process won the race while we were waiting. # thread in this process won the race while we were waiting.
with self._lock: cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
if thread_id in self._thread_sandboxes: if cached_id is not None:
existing_id = self._thread_sandboxes[thread_id] return cached_id
if existing_id in self._sandboxes:
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)")
self._last_activity[existing_id] = time.time()
return existing_id
if sandbox_id in self._warm_pool:
info, _ = self._warm_pool.pop(sandbox_id)
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)")
return sandbox_id
# Backend discovery: another process may have created the container. # Backend discovery: another process may have created the container.
discovered = self._backend.discover(sandbox_id) discovered = self._backend.discover(sandbox_id)
if discovered is not None: if discovered is not None:
sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url) return self._register_discovered_sandbox(thread_id, discovered)
with self._lock:
self._sandboxes[discovered.sandbox_id] = sandbox
self._sandbox_infos[discovered.sandbox_id] = discovered
self._last_activity[discovered.sandbox_id] = time.time()
self._thread_sandboxes[thread_id] = discovered.sandbox_id
logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}")
return discovered.sandbox_id
return self._create_sandbox(thread_id, sandbox_id) return self._create_sandbox(thread_id, sandbox_id)
finally: finally:
if locked: if locked:
_unlock_file(lock_file) _unlock_file(lock_file)
async def _discover_or_create_with_lock_async(self, thread_id: str, sandbox_id: str) -> str:
"""Async counterpart to ``_discover_or_create_with_lock``."""
paths = get_paths()
user_id = get_effective_user_id()
await asyncio.to_thread(paths.ensure_thread_dirs, thread_id, user_id=user_id)
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
lock_file = await asyncio.to_thread(_open_lock_file, lock_path)
locked = False
try:
await asyncio.to_thread(_lock_file_exclusive, lock_file)
locked = True
# Re-check in-process caches under the file lock in case another
# thread in this process won the race while we were waiting.
cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
if cached_id is not None:
return cached_id
# Backend discovery is sync because local discovery may inspect
# Docker and perform a health check; keep it off the event loop.
discovered = await asyncio.to_thread(self._backend.discover, sandbox_id)
if discovered is not None:
return self._register_discovered_sandbox(thread_id, discovered)
return await self._create_sandbox_async(thread_id, sandbox_id)
finally:
if locked:
await asyncio.to_thread(_unlock_file, lock_file)
await asyncio.to_thread(lock_file.close)
def _evict_oldest_warm(self) -> str | None: def _evict_oldest_warm(self) -> str | None:
"""Destroy the oldest container in the warm pool to free capacity. """Destroy the oldest container in the warm pool to free capacity.
@@ -574,18 +734,10 @@ class AioSandboxProvider(SandboxProvider):
# Enforce replicas: only warm-pool containers count toward eviction budget. # Enforce replicas: only warm-pool containers count toward eviction budget.
# Active sandboxes are in use by live threads and must not be forcibly stopped. # Active sandboxes are in use by live threads and must not be forcibly stopped.
replicas = self._config.get("replicas", DEFAULT_REPLICAS) replicas, total = self._replica_count()
with self._lock:
total = len(self._sandboxes) + len(self._warm_pool)
if total >= replicas: if total >= replicas:
evicted = self._evict_oldest_warm() evicted = self._evict_oldest_warm()
if evicted: self._log_replicas_soft_cap(replicas, sandbox_id, evicted)
logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
else:
# All slots are occupied by active sandboxes — proceed anyway and log.
# The replicas limit is a soft cap; we never forcibly stop a container
# that is actively serving a thread.
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None) info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None)
@@ -594,16 +746,27 @@ class AioSandboxProvider(SandboxProvider):
self._backend.destroy(info) self._backend.destroy(info)
raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}") raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) return self._register_created_sandbox(thread_id, sandbox_id, info)
with self._lock:
self._sandboxes[sandbox_id] = sandbox
self._sandbox_infos[sandbox_id] = info
self._last_activity[sandbox_id] = time.time()
if thread_id:
self._thread_sandboxes[thread_id] = sandbox_id
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") async def _create_sandbox_async(self, thread_id: str | None, sandbox_id: str) -> str:
return sandbox_id """Async counterpart to ``_create_sandbox``."""
extra_mounts = await asyncio.to_thread(self._get_extra_mounts, thread_id)
# Enforce replicas: only warm-pool containers count toward eviction budget.
# Active sandboxes are in use by live threads and must not be forcibly stopped.
replicas, total = self._replica_count()
if total >= replicas:
evicted = await asyncio.to_thread(self._evict_oldest_warm)
self._log_replicas_soft_cap(replicas, sandbox_id, evicted)
info = await asyncio.to_thread(self._backend.create, thread_id, sandbox_id, extra_mounts=extra_mounts or None)
# Wait for sandbox to be ready without blocking the event loop.
if not await wait_for_sandbox_ready_async(info.sandbox_url, timeout=60):
await asyncio.to_thread(self._backend.destroy, info)
raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")
return self._register_created_sandbox(thread_id, sandbox_id, info)
def get(self, sandbox_id: str) -> Sandbox | None: def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox by ID. Updates last activity timestamp. """Get a sandbox by ID. Updates last activity timestamp.
@@ -2,10 +2,12 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import httpx
import requests import requests
from .sandbox_info import SandboxInfo from .sandbox_info import SandboxInfo
@@ -35,6 +37,34 @@ def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool:
return False return False
async def wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool:
"""Async variant of sandbox readiness polling.
Use this from async runtime paths so sandbox startup waits do not block the
event loop. The synchronous ``wait_for_sandbox_ready`` function remains for
existing synchronous backend/provider call sites.
"""
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
async with httpx.AsyncClient(timeout=5) as client:
while True:
remaining = deadline - loop.time()
if remaining <= 0:
break
try:
response = await client.get(f"{sandbox_url}/v1/sandbox", timeout=min(5.0, remaining))
if response.status_code == 200:
return True
except httpx.RequestError:
pass
remaining = deadline - loop.time()
if remaining <= 0:
break
await asyncio.sleep(min(poll_interval, remaining))
return False
class SandboxBackend(ABC): class SandboxBackend(ABC):
"""Abstract base for sandbox provisioning backends. """Abstract base for sandbox provisioning backends.
@@ -44,7 +74,7 @@ class SandboxBackend(ABC):
""" """
@abstractmethod @abstractmethod
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""Create/provision a new sandbox. """Create/provision a new sandbox.
Args: Args:
@@ -241,7 +241,7 @@ class LocalContainerBackend(SandboxBackend):
# ── SandboxBackend interface ────────────────────────────────────────── # ── SandboxBackend interface ──────────────────────────────────────────
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""Start a new container and return its connection info. """Start a new container and return its connection info.
Args: Args:
@@ -59,7 +59,7 @@ class RemoteSandboxBackend(SandboxBackend):
def create( def create(
self, self,
thread_id: str, thread_id: str | None,
sandbox_id: str, sandbox_id: str,
extra_mounts: list[tuple[str, str, bool]] | None = None, extra_mounts: list[tuple[str, str, bool]] | None = None,
) -> SandboxInfo: ) -> SandboxInfo:
@@ -132,7 +132,7 @@ class RemoteSandboxBackend(SandboxBackend):
logger.warning("Provisioner list_running failed: %s", exc) logger.warning("Provisioner list_running failed: %s", exc)
return [] return []
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: def _provisioner_create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
"""POST /api/sandboxes → create Pod + Service.""" """POST /api/sandboxes → create Pod + Service."""
try: try:
resp = requests.post( resp = requests.post(
@@ -1,3 +1,4 @@
import asyncio
import logging import logging
from typing import NotRequired, override from typing import NotRequired, override
@@ -48,6 +49,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
logger.info(f"Acquiring sandbox {sandbox_id}") logger.info(f"Acquiring sandbox {sandbox_id}")
return sandbox_id return sandbox_id
async def _acquire_sandbox_async(self, thread_id: str) -> str:
provider = get_sandbox_provider()
sandbox_id = await provider.acquire_async(thread_id)
logger.info(f"Acquiring sandbox {sandbox_id}")
return sandbox_id
async def _release_sandbox_async(self, sandbox_id: str) -> None:
await asyncio.to_thread(get_sandbox_provider().release, sandbox_id)
@override @override
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
# Skip acquisition if lazy_init is enabled # Skip acquisition if lazy_init is enabled
@@ -64,6 +74,23 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
return {"sandbox": {"sandbox_id": sandbox_id}} return {"sandbox": {"sandbox_id": sandbox_id}}
return super().before_agent(state, runtime) return super().before_agent(state, runtime)
@override
async def abefore_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
# Skip acquisition if lazy_init is enabled
if self._lazy_init:
return await super().abefore_agent(state, runtime)
# Eager initialization (original behavior), but use the async provider
# hook so blocking sandbox startup/polling runs outside the event loop.
if "sandbox" not in state or state["sandbox"] is None:
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
return await super().abefore_agent(state, runtime)
sandbox_id = await self._acquire_sandbox_async(thread_id)
logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
return {"sandbox": {"sandbox_id": sandbox_id}}
return await super().abefore_agent(state, runtime)
@override @override
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
sandbox = state.get("sandbox") sandbox = state.get("sandbox")
@@ -81,3 +108,21 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
# No sandbox to release # No sandbox to release
return super().after_agent(state, runtime) return super().after_agent(state, runtime)
@override
async def aafter_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
sandbox = state.get("sandbox")
if sandbox is not None:
sandbox_id = sandbox["sandbox_id"]
logger.info(f"Releasing sandbox {sandbox_id}")
await self._release_sandbox_async(sandbox_id)
return None
if (runtime.context or {}).get("sandbox_id") is not None:
sandbox_id = runtime.context.get("sandbox_id")
logger.info(f"Releasing sandbox {sandbox_id} from context")
await self._release_sandbox_async(sandbox_id)
return None
# No sandbox to release
return await super().aafter_agent(state, runtime)
@@ -1,3 +1,4 @@
import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from deerflow.config import get_app_config from deerflow.config import get_app_config
@@ -19,6 +20,16 @@ class SandboxProvider(ABC):
""" """
pass pass
async def acquire_async(self, thread_id: str | None = None) -> str:
"""Acquire a sandbox without blocking the event loop.
Most sandbox providers expose a synchronous lifecycle API because local
Docker/provisioner operations are blocking. Async runtimes should call
this method so those blocking operations run in a worker thread instead
of stalling the event loop.
"""
return await asyncio.to_thread(self.acquire, thread_id)
@abstractmethod @abstractmethod
def get(self, sandbox_id: str) -> Sandbox | None: def get(self, sandbox_id: str) -> Sandbox | None:
"""Get a sandbox environment by ID. """Get a sandbox environment by ID.
@@ -1,6 +1,8 @@
import asyncio
import posixpath import posixpath
import re import re
import shlex import shlex
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from langchain.tools import tool from langchain.tools import tool
@@ -1111,6 +1113,68 @@ def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
return sandbox return sandbox
async def ensure_sandbox_initialized_async(runtime: Runtime | None = None) -> Sandbox:
"""Async counterpart to ``ensure_sandbox_initialized`` for tool runtimes.
This keeps lazy sandbox acquisition on the async provider hook, so AIO
sandbox startup and readiness polling do not fall back to synchronous
``provider.acquire()`` during async tool execution.
"""
if runtime is None:
raise SandboxRuntimeError("Tool runtime not available")
if runtime.state is None:
raise SandboxRuntimeError("Tool runtime state not available")
sandbox_state = runtime.state.get("sandbox")
if sandbox_state is not None:
sandbox_id = sandbox_state.get("sandbox_id")
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
if sandbox is not None:
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id
return sandbox
thread_id = runtime.context.get("thread_id") if runtime.context else None
if thread_id is None:
thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None
if thread_id is None:
raise SandboxRuntimeError("Thread ID not available in runtime context")
provider = get_sandbox_provider()
sandbox_id = await provider.acquire_async(thread_id)
runtime.state["sandbox"] = {"sandbox_id": sandbox_id}
sandbox = provider.get(sandbox_id)
if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id
return sandbox
async def _run_sync_tool_after_async_sandbox_init(
func: Callable[..., str] | None,
runtime: Runtime,
*args: object,
) -> str:
"""Initialize lazily via async provider, then run sync tool body off-thread."""
try:
await ensure_sandbox_initialized_async(runtime)
except SandboxError as e:
return f"Error: {e}"
except Exception as e:
return f"Error: Unexpected error initializing sandbox: {_sanitize_error(e, runtime)}"
if func is None:
return "Error: Tool implementation not available"
return await asyncio.to_thread(func, runtime, *args)
def ensure_thread_directories_exist(runtime: Runtime | None) -> None: def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
"""Ensure thread data directories (workspace, uploads, outputs) exist. """Ensure thread data directories (workspace, uploads, outputs) exist.
@@ -1273,6 +1337,13 @@ def bash_tool(runtime: Runtime, description: str, command: str) -> str:
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}"
async def _bash_tool_async(runtime: Runtime, description: str, command: str) -> str:
return await _run_sync_tool_after_async_sandbox_init(bash_tool.func, runtime, description, command)
bash_tool.coroutine = _bash_tool_async
@tool("ls", parse_docstring=True) @tool("ls", parse_docstring=True)
def ls_tool(runtime: Runtime, description: str, path: str) -> str: def ls_tool(runtime: Runtime, description: str, path: str) -> str:
"""List the contents of a directory up to 2 levels deep in tree format. """List the contents of a directory up to 2 levels deep in tree format.
@@ -1320,6 +1391,13 @@ def ls_tool(runtime: Runtime, description: str, path: str) -> str:
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
async def _ls_tool_async(runtime: Runtime, description: str, path: str) -> str:
return await _run_sync_tool_after_async_sandbox_init(ls_tool.func, runtime, description, path)
ls_tool.coroutine = _ls_tool_async
@tool("glob", parse_docstring=True) @tool("glob", parse_docstring=True)
def glob_tool( def glob_tool(
runtime: Runtime, runtime: Runtime,
@@ -1370,6 +1448,28 @@ def glob_tool(
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
async def _glob_tool_async(
runtime: Runtime,
description: str,
pattern: str,
path: str,
include_dirs: bool = False,
max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
) -> str:
return await _run_sync_tool_after_async_sandbox_init(
glob_tool.func,
runtime,
description,
pattern,
path,
include_dirs,
max_results,
)
glob_tool.coroutine = _glob_tool_async
@tool("grep", parse_docstring=True) @tool("grep", parse_docstring=True)
def grep_tool( def grep_tool(
runtime: Runtime, runtime: Runtime,
@@ -1440,6 +1540,32 @@ def grep_tool(
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
async def _grep_tool_async(
runtime: Runtime,
description: str,
pattern: str,
path: str,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = _DEFAULT_GREP_MAX_RESULTS,
) -> str:
return await _run_sync_tool_after_async_sandbox_init(
grep_tool.func,
runtime,
description,
pattern,
path,
glob,
literal,
case_sensitive,
max_results,
)
grep_tool.coroutine = _grep_tool_async
@tool("read_file", parse_docstring=True) @tool("read_file", parse_docstring=True)
def read_file_tool( def read_file_tool(
runtime: Runtime, runtime: Runtime,
@@ -1495,6 +1621,19 @@ def read_file_tool(
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}"
async def _read_file_tool_async(
runtime: Runtime,
description: str,
path: str,
start_line: int | None = None,
end_line: int | None = None,
) -> str:
return await _run_sync_tool_after_async_sandbox_init(read_file_tool.func, runtime, description, path, start_line, end_line)
read_file_tool.coroutine = _read_file_tool_async
@tool("write_file", parse_docstring=True) @tool("write_file", parse_docstring=True)
def write_file_tool( def write_file_tool(
runtime: Runtime, runtime: Runtime,
@@ -1536,6 +1675,19 @@ def write_file_tool(
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
async def _write_file_tool_async(
runtime: Runtime,
description: str,
path: str,
content: str,
append: bool = False,
) -> str:
return await _run_sync_tool_after_async_sandbox_init(write_file_tool.func, runtime, description, path, content, append)
write_file_tool.coroutine = _write_file_tool_async
@tool("str_replace", parse_docstring=True) @tool("str_replace", parse_docstring=True)
def str_replace_tool( def str_replace_tool(
runtime: Runtime, runtime: Runtime,
@@ -1585,3 +1737,25 @@ def str_replace_tool(
return f"Error: Permission denied accessing file: {requested_path}" return f"Error: Permission denied accessing file: {requested_path}"
except Exception as e: except Exception as e:
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}" return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}"
async def _str_replace_tool_async(
runtime: Runtime,
description: str,
path: str,
old_str: str,
new_str: str,
replace_all: bool = False,
) -> str:
return await _run_sync_tool_after_async_sandbox_init(
str_replace_tool.func,
runtime,
description,
path,
old_str,
new_str,
replace_all,
)
str_replace_tool.coroutine = _str_replace_tool_async
+177
View File
@@ -1,5 +1,6 @@
"""Tests for AioSandboxProvider mount helpers.""" """Tests for AioSandboxProvider mount helpers."""
import asyncio
import importlib import importlib
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -140,6 +141,182 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
assert unlock_calls == [] assert unlock_calls == []
@pytest.mark.anyio
async def test_acquire_async_uses_async_readiness_polling(monkeypatch):
"""AioSandboxProvider async creation must not use sync readiness polling."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(None)
provider._config = {"replicas": 3}
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
provider._backend = SimpleNamespace(
create=MagicMock(return_value=aio_mod.SandboxInfo(sandbox_id="sandbox-async", sandbox_url="http://sandbox")),
destroy=MagicMock(),
discover=MagicMock(return_value=None),
)
async_readiness_calls: list[tuple[str, int]] = []
async def fake_wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool:
async_readiness_calls.append((sandbox_url, timeout))
return True
monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready_async", fake_wait_for_sandbox_ready_async)
monkeypatch.setattr(
aio_mod,
"wait_for_sandbox_ready",
lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("sync readiness should not be used")),
)
sandbox_id = await provider._create_sandbox_async("thread-async", "sandbox-async")
assert sandbox_id == "sandbox-async"
assert async_readiness_calls == [("http://sandbox", 60)]
assert provider._backend.destroy.call_count == 0
assert provider._thread_sandboxes["thread-async"] == "sandbox-async"
@pytest.mark.anyio
async def test_discover_or_create_with_lock_async_offloads_lock_file_open_and_close(tmp_path, monkeypatch):
"""Async lock path must not open or close lock files on the event loop."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._discover_or_create_with_lock_async = aio_mod.AioSandboxProvider._discover_or_create_with_lock_async.__get__(
provider,
aio_mod.AioSandboxProvider,
)
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {"thread-async-lock": "sandbox-async-lock"}
provider._sandboxes = {"sandbox-async-lock": aio_mod.AioSandbox(id="sandbox-async-lock", base_url="http://sandbox")}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
provider._backend = SimpleNamespace(discover=MagicMock(return_value=None))
monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))
to_thread_calls: list[object] = []
async def fake_to_thread(func, /, *args, **kwargs):
to_thread_calls.append(func)
return func(*args, **kwargs)
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread)
sandbox_id = await provider._discover_or_create_with_lock_async("thread-async-lock", "sandbox-async-lock")
assert sandbox_id == "sandbox-async-lock"
assert aio_mod._open_lock_file in to_thread_calls
assert any(getattr(func, "__name__", "") == "close" for func in to_thread_calls)
@pytest.mark.anyio
async def test_acquire_thread_lock_async_uses_dedicated_executor(monkeypatch):
"""Per-thread lock waits should not consume the default asyncio.to_thread pool."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
lock = aio_mod.threading.Lock()
async def fail_to_thread(*_args, **_kwargs):
raise AssertionError("thread-lock acquisition must not use asyncio.to_thread")
monkeypatch.setattr(aio_mod.asyncio, "to_thread", fail_to_thread)
await aio_mod._acquire_thread_lock_async(lock)
try:
assert not lock.acquire(blocking=False)
finally:
lock.release()
@pytest.mark.anyio
async def test_acquire_async_cancellation_does_not_leak_thread_lock(tmp_path):
"""Cancelled async lock waiters must not leave the per-thread lock held."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
thread_id = "thread-cancel-lock"
thread_lock = provider._get_thread_lock(thread_id)
thread_lock.acquire()
task = asyncio.create_task(provider.acquire_async(thread_id))
await asyncio.sleep(0.05)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
thread_lock.release()
deadline = asyncio.get_running_loop().time() + 1
while asyncio.get_running_loop().time() < deadline:
acquired = thread_lock.acquire(blocking=False)
if acquired:
thread_lock.release()
return
await asyncio.sleep(0.01)
pytest.fail("provider thread lock was leaked after cancelling acquire_async")
@pytest.mark.anyio
async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path, monkeypatch):
"""A cancelled waiter must not prevent the next live waiter from acquiring."""
aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
provider = _make_provider(tmp_path)
provider._thread_locks = {}
provider._warm_pool = {}
provider._sandbox_infos = {}
provider._thread_sandboxes = {}
provider._last_activity = {}
provider._lock = aio_mod.threading.Lock()
async def fake_acquire_internal_async(thread_id: str | None) -> str:
assert thread_id == "thread-successor-lock"
await asyncio.sleep(0)
return "sandbox-successor"
monkeypatch.setattr(provider, "_acquire_internal_async", fake_acquire_internal_async)
thread_id = "thread-successor-lock"
thread_lock = provider._get_thread_lock(thread_id)
thread_lock.acquire()
cancelled_waiter = asyncio.create_task(provider.acquire_async(thread_id))
await asyncio.sleep(0.05)
cancelled_waiter.cancel()
try:
await cancelled_waiter
except asyncio.CancelledError:
pass
live_waiter = asyncio.create_task(provider.acquire_async(thread_id))
thread_lock.release()
assert await asyncio.wait_for(live_waiter, timeout=1) == "sandbox-successor"
deadline = asyncio.get_running_loop().time() + 1
while asyncio.get_running_loop().time() < deadline:
acquired = thread_lock.acquire(blocking=False)
if acquired:
thread_lock.release()
return
await asyncio.sleep(0.01)
pytest.fail("provider thread lock was not released after successor acquire_async")
def test_remote_backend_create_forwards_effective_user_id(monkeypatch): def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
"""Provisioner mode must receive user_id so PVC subPath matches user isolation.""" """Provisioner mode must receive user_id so PVC subPath matches user isolation."""
remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend") remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend")
+119
View File
@@ -0,0 +1,119 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from deerflow.community.aio_sandbox import backend as readiness
class _FakeAsyncClient:
def __init__(self, *, responses: list[object], calls: list[str], timeout: float, request_timeouts: list[float] | None = None) -> None:
self._responses = responses
self._calls = calls
self._timeout = timeout
self._request_timeouts = request_timeouts
async def __aenter__(self) -> _FakeAsyncClient:
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
return None
async def get(self, url: str, *, timeout: float):
self._calls.append(url)
if self._request_timeouts is not None:
self._request_timeouts.append(timeout)
response = self._responses.pop(0)
if isinstance(response, BaseException):
raise response
return response
class _FakeLoop:
def __init__(self, times: list[float]) -> None:
self._times = times
self._index = 0
def time(self) -> float:
value = self._times[self._index]
self._index += 1
return value
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_uses_nonblocking_polling(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[str] = []
sleeps: list[float] = []
def fake_client(*, timeout: float):
return _FakeAsyncClient(
responses=[SimpleNamespace(status_code=503), SimpleNamespace(status_code=200)],
calls=calls,
timeout=timeout,
)
async def fake_sleep(delay: float) -> None:
sleeps.append(delay)
monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client)
monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep)
monkeypatch.setattr(readiness.requests, "get", lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("requests.get should not be used")))
monkeypatch.setattr(readiness.time, "sleep", lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("time.sleep should not be used")))
assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.05) is True
assert calls == ["http://sandbox/v1/sandbox", "http://sandbox/v1/sandbox"]
assert sleeps == [0.05]
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_retries_request_errors(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[str] = []
sleeps: list[float] = []
def fake_client(*, timeout: float):
return _FakeAsyncClient(
responses=[readiness.httpx.ConnectError("not ready"), SimpleNamespace(status_code=200)],
calls=calls,
timeout=timeout,
)
async def fake_sleep(delay: float) -> None:
sleeps.append(delay)
monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client)
monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep)
assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.01) is True
assert len(calls) == 2
assert sleeps == [0.01]
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_clamps_request_and_sleep_to_deadline(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[str] = []
request_timeouts: list[float] = []
sleeps: list[float] = []
def fake_client(*, timeout: float):
return _FakeAsyncClient(
responses=[SimpleNamespace(status_code=503)],
calls=calls,
timeout=timeout,
request_timeouts=request_timeouts,
)
async def fake_sleep(delay: float) -> None:
sleeps.append(delay)
monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client)
monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep)
monkeypatch.setattr(readiness.asyncio, "get_running_loop", lambda: _FakeLoop([100.0, 100.5, 101.75, 102.0]))
assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=2, poll_interval=1.0) is False
assert calls == ["http://sandbox/v1/sandbox"]
assert request_timeouts == [1.5]
assert sleeps == [0.25]
@@ -159,6 +159,26 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch):
assert info.sandbox_url == "http://k3s:31001" assert info.sandbox_url == "http://k3s:31001"
def test_provisioner_create_accepts_anonymous_thread_id(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002")
def mock_post(url: str, json: dict, timeout: int):
assert url == "http://provisioner:8002/api/sandboxes"
assert json == {
"sandbox_id": "anon123",
"thread_id": None,
"user_id": "test-user-autouse",
}
assert timeout == 30
return _StubResponse(payload={"sandbox_id": "anon123", "sandbox_url": "http://k3s:31002"})
monkeypatch.setattr(requests, "post", mock_post)
info = backend.create(None, "anon123")
assert info.sandbox_id == "anon123"
assert info.sandbox_url == "http://k3s:31002"
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch): def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
backend = RemoteSandboxBackend("http://provisioner:8002") backend = RemoteSandboxBackend("http://provisioner:8002")
+225
View File
@@ -0,0 +1,225 @@
from __future__ import annotations
import asyncio
import pytest
from langchain.agents.middleware import AgentMiddleware
from langchain.tools import ToolRuntime
from langgraph.runtime import Runtime
from deerflow.sandbox.middleware import SandboxMiddleware
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider, reset_sandbox_provider, set_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.tools import ls_tool
class _SyncProvider(SandboxProvider):
def __init__(self) -> None:
self.thread_ids: list[str | None] = []
def acquire(self, thread_id: str | None = None) -> str:
self.thread_ids.append(thread_id)
return "sync-sandbox"
def get(self, sandbox_id: str) -> Sandbox | None:
return None
def release(self, sandbox_id: str) -> None:
return None
class _SandboxStub(Sandbox):
def execute_command(self, command: str) -> str:
return "OK"
def read_file(self, path: str) -> str:
return "content"
def download_file(self, path: str) -> bytes:
return b"content"
def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
return ["/mnt/user-data/workspace/file.txt"]
def write_file(self, path: str, content: str, append: bool = False) -> None:
return None
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
return [], False
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
return [], False
def update_file(self, path: str, content: bytes) -> None:
return None
class _AsyncOnlyProvider(SandboxProvider):
def __init__(self) -> None:
self.thread_ids: list[str | None] = []
self.released_ids: list[str] = []
self.sandbox = _SandboxStub("async-sandbox")
def acquire(self, thread_id: str | None = None) -> str:
raise AssertionError("async middleware should not call sync acquire")
async def acquire_async(self, thread_id: str | None = None) -> str:
self.thread_ids.append(thread_id)
return "async-sandbox"
def get(self, sandbox_id: str) -> Sandbox | None:
if sandbox_id == "async-sandbox":
return self.sandbox
return None
def release(self, sandbox_id: str) -> None:
self.released_ids.append(sandbox_id)
return None
@pytest.mark.anyio
async def test_provider_default_acquire_async_offloads_sync_acquire(monkeypatch: pytest.MonkeyPatch) -> None:
provider = _SyncProvider()
calls: list[tuple[object, tuple[object, ...]]] = []
async def fake_to_thread(func, /, *args):
calls.append((func, args))
return func(*args)
monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
sandbox_id = await provider.acquire_async("thread-1")
assert sandbox_id == "sync-sandbox"
assert provider.thread_ids == ["thread-1"]
assert calls == [(provider.acquire, ("thread-1",))]
@pytest.mark.anyio
async def test_abefore_agent_uses_async_provider_acquire() -> None:
provider = _AsyncOnlyProvider()
set_sandbox_provider(provider)
try:
middleware = SandboxMiddleware(lazy_init=False)
result = await middleware.abefore_agent({}, Runtime(context={"thread_id": "thread-2"}))
finally:
reset_sandbox_provider()
assert result == {"sandbox": {"sandbox_id": "async-sandbox"}}
assert provider.thread_ids == ["thread-2"]
@pytest.mark.anyio
@pytest.mark.parametrize(
("middleware", "state", "runtime"),
[
(SandboxMiddleware(lazy_init=True), {}, Runtime(context={"thread_id": "thread-lazy"})),
(SandboxMiddleware(lazy_init=False), {}, Runtime(context={})),
(SandboxMiddleware(lazy_init=False), {"sandbox": {"sandbox_id": "existing"}}, Runtime(context={"thread_id": "thread-existing"})),
],
)
async def test_abefore_agent_delegates_to_super_when_not_acquiring(
monkeypatch: pytest.MonkeyPatch,
middleware: SandboxMiddleware,
state: dict,
runtime: Runtime,
) -> None:
calls: list[tuple[dict, Runtime]] = []
async def fake_super_abefore_agent(self, state_arg, runtime_arg):
calls.append((state_arg, runtime_arg))
return {"delegated": True}
monkeypatch.setattr(AgentMiddleware, "abefore_agent", fake_super_abefore_agent)
result = await middleware.abefore_agent(state, runtime)
assert result == {"delegated": True}
assert calls == [(state, runtime)]
@pytest.mark.anyio
async def test_default_lazy_tool_acquisition_uses_async_provider() -> None:
provider = _AsyncOnlyProvider()
set_sandbox_provider(provider)
try:
runtime = ToolRuntime(
state={},
context={"thread_id": "thread-lazy"},
config={"configurable": {}},
stream_writer=lambda _: None,
tools=[],
tool_call_id="call-1",
store=None,
)
result = await ls_tool.ainvoke({"runtime": runtime, "description": "list workspace", "path": "/mnt/user-data/workspace"})
finally:
reset_sandbox_provider()
assert result == "/mnt/user-data/workspace/file.txt"
assert provider.thread_ids == ["thread-lazy"]
assert runtime.state["sandbox"] == {"sandbox_id": "async-sandbox"}
assert runtime.context["sandbox_id"] == "async-sandbox"
@pytest.mark.anyio
@pytest.mark.parametrize(
("state", "runtime", "expected_sandbox_id"),
[
({"sandbox": {"sandbox_id": "state-sandbox"}}, Runtime(context={}), "state-sandbox"),
({}, Runtime(context={"sandbox_id": "context-sandbox"}), "context-sandbox"),
],
)
async def test_aafter_agent_releases_sandbox_off_thread(
monkeypatch: pytest.MonkeyPatch,
state: dict,
runtime: Runtime,
expected_sandbox_id: str,
) -> None:
provider = _AsyncOnlyProvider()
to_thread_calls: list[tuple[object, tuple[object, ...]]] = []
async def fake_to_thread(func, /, *args):
to_thread_calls.append((func, args))
return func(*args)
monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
set_sandbox_provider(provider)
try:
result = await SandboxMiddleware().aafter_agent(state, runtime)
finally:
reset_sandbox_provider()
assert result is None
assert provider.released_ids == [expected_sandbox_id]
assert to_thread_calls == [(provider.release, (expected_sandbox_id,))]
@pytest.mark.anyio
async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[tuple[dict, Runtime]] = []
async def fake_super_aafter_agent(self, state_arg, runtime_arg):
calls.append((state_arg, runtime_arg))
return {"delegated": True}
monkeypatch.setattr(AgentMiddleware, "aafter_agent", fake_super_aafter_agent)
state = {}
runtime = Runtime(context={})
result = await SandboxMiddleware().aafter_agent(state, runtime)
assert result == {"delegated": True}
assert calls == [(state, runtime)]