mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
fix(sandbox): avoid blocking sandbox readiness polling (#2822)
* fix(sandbox): offload async sandbox acquisition

  Run blocking sandbox provider acquisition through the async provider hook so eager sandbox setup does not stall the event loop.

* fix(sandbox): add async readiness polling

  Introduce an async sandbox readiness poller using httpx and asyncio.sleep while preserving the existing synchronous API.

* test(sandbox): cover async readiness polling

  Lock in non-blocking readiness behavior so the async helper does not regress to requests.get or time.sleep.

* fix(sandbox): allow anonymous backend creation
* fix(sandbox): use async readiness in provider acquisition
* fix(sandbox): use async acquisition for lazy tools
* test(sandbox): cover anonymous remote creation
* fix(sandbox): clamp async readiness timeout budget
* fix(sandbox): offload async lock file handling
* fix(sandbox): delegate async middleware fallthrough
* docs(sandbox): document async acquisition path
* fix(sandbox): offload async sandbox release
* docs(sandbox): mention async release hook
* fix(sandbox): address async lock review

  Reduce duplicate sync/async sandbox acquisition state handling and move async thread-lock waits onto a dedicated executor with cancellation-safe cleanup.

* chore: retrigger ci

  Retrigger GitHub Actions after upstream main fixed the stale PR merge lint failure.

* test(sandbox): sync backend unit fixtures

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
+1
-1
@@ -236,7 +236,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
### Sandbox System (`packages/harness/deerflow/sandbox/`)
|
||||
|
||||
**Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle
|
||||
**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop.
|
||||
**Implementations**:
|
||||
- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`.
|
||||
- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation
|
||||
|
||||
+1
-1
@@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern:
|
||||
Per-thread isolated execution with virtual path translation:
|
||||
|
||||
- **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir`
|
||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/)
|
||||
- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop.
|
||||
- **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories
|
||||
- **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory
|
||||
- **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths
|
||||
|
||||
@@ -10,6 +10,7 @@ The provider itself handles:
|
||||
- Mount computation (thread-specific, skills)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import hashlib
|
||||
import logging
|
||||
@@ -18,6 +19,7 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
@@ -32,7 +34,7 @@ from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider
|
||||
|
||||
from .aio_sandbox import AioSandbox
|
||||
from .backend import SandboxBackend, wait_for_sandbox_ready
|
||||
from .backend import SandboxBackend, wait_for_sandbox_ready, wait_for_sandbox_ready_async
|
||||
from .local_backend import LocalContainerBackend
|
||||
from .remote_backend import RemoteSandboxBackend
|
||||
from .sandbox_info import SandboxInfo
|
||||
@@ -46,6 +48,9 @@ DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox"
|
||||
DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds
|
||||
DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers
|
||||
IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds
|
||||
# Size of the dedicated pool used to wait on per-thread ``threading.Lock``
# objects from async code: CPU count (1 if unknown) plus 4, capped at 32.
THREAD_LOCK_EXECUTOR_WORKERS = min(32, (os.cpu_count() or 1) + 4)

# Separate executor so blocking lock waits do not go through the event loop's
# default executor.
_THREAD_LOCK_EXECUTOR = ThreadPoolExecutor(
    max_workers=THREAD_LOCK_EXECUTOR_WORKERS,
    thread_name_prefix="sandbox-lock-wait",
)

# At interpreter exit, shut the pool down without waiting and cancel any
# lock-wait jobs that are still queued.
atexit.register(_THREAD_LOCK_EXECUTOR.shutdown, wait=False, cancel_futures=True)
|
||||
|
||||
|
||||
def _lock_file_exclusive(lock_file) -> None:
|
||||
@@ -66,6 +71,40 @@ def _unlock_file(lock_file) -> None:
|
||||
msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
|
||||
|
||||
def _open_lock_file(lock_path):
|
||||
return open(lock_path, "a", encoding="utf-8")
|
||||
|
||||
|
||||
async def _acquire_thread_lock_async(lock: threading.Lock) -> None:
    """Acquire ``lock`` from a coroutine without polling or the default executor.

    The blocking ``lock.acquire(True)`` call runs on ``_THREAD_LOCK_EXECUTOR``
    and the coroutine awaits its completion.

    Raises:
        RuntimeError: If the executor job reports the lock was not acquired.
        asyncio.CancelledError: If the awaiting coroutine is cancelled.
    """
    running_loop = asyncio.get_running_loop()
    pending = running_loop.run_in_executor(_THREAD_LOCK_EXECUTOR, lock.acquire, True)

    try:
        got_lock = await asyncio.shield(pending)
    except asyncio.CancelledError:
        # The shielded executor job keeps running and may still obtain the
        # lock after we stop waiting; release it from a done-callback so the
        # lock is not left held with no owner.
        pending.add_done_callback(lambda task: _release_cancelled_lock_acquire(lock, task))
        raise

    if got_lock:
        return
    raise RuntimeError("Failed to acquire sandbox thread lock")
|
||||
|
||||
|
||||
def _release_cancelled_lock_acquire(lock: threading.Lock, task: asyncio.Future[bool]) -> None:
|
||||
"""Release a lock acquired after its awaiting coroutine was cancelled."""
|
||||
if task.cancelled():
|
||||
return
|
||||
|
||||
try:
|
||||
acquired = task.result()
|
||||
except Exception as e:
|
||||
logger.warning(f"Cancelled sandbox lock acquisition finished with error: {e}")
|
||||
return
|
||||
|
||||
if acquired:
|
||||
lock.release()
|
||||
|
||||
|
||||
class AioSandboxProvider(SandboxProvider):
|
||||
"""Sandbox provider that manages containers running the AIO sandbox.
|
||||
|
||||
@@ -416,6 +455,96 @@ class AioSandboxProvider(SandboxProvider):
|
||||
self._thread_locks[thread_id] = threading.Lock()
|
||||
return self._thread_locks[thread_id]
|
||||
|
||||
def _sandbox_id_for_thread(self, thread_id: str | None) -> str:
|
||||
"""Return deterministic IDs for thread sandboxes and random IDs otherwise."""
|
||||
return self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]
|
||||
|
||||
def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool = False) -> str | None:
|
||||
"""Reuse an active in-process sandbox for a thread if one is still tracked."""
|
||||
if thread_id is None:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if thread_id not in self._thread_sandboxes:
|
||||
return None
|
||||
|
||||
existing_id = self._thread_sandboxes[thread_id]
|
||||
if existing_id in self._sandboxes:
|
||||
suffix = " (post-lock check)" if post_lock else ""
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
|
||||
del self._thread_sandboxes[thread_id]
|
||||
return None
|
||||
|
||||
def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None:
|
||||
"""Promote a warm-pool sandbox back to active tracking if available."""
|
||||
if thread_id is None:
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if sandbox_id not in self._warm_pool:
|
||||
return None
|
||||
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
|
||||
suffix = " (post-lock check)" if post_lock else f" at {info.sandbox_url}"
|
||||
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id}{suffix}")
|
||||
return sandbox_id
|
||||
|
||||
def _recheck_cached_sandbox(self, thread_id: str, sandbox_id: str) -> str | None:
|
||||
"""Re-check in-memory caches after acquiring the cross-process file lock."""
|
||||
return self._reuse_in_process_sandbox(thread_id, post_lock=True) or self._reclaim_warm_pool_sandbox(thread_id, sandbox_id, post_lock=True)
|
||||
|
||||
def _register_discovered_sandbox(self, thread_id: str, info: SandboxInfo) -> str:
    """Start tracking a sandbox that the backend reported as already existing.

    Builds an ``AioSandbox`` client for ``info`` and records it, its info, a
    fresh last-activity timestamp and the thread mapping under ``self._lock``.

    Returns:
        ``info.sandbox_id``.
    """
    client = AioSandbox(id=info.sandbox_id, base_url=info.sandbox_url)
    now_key = info.sandbox_id
    with self._lock:
        self._sandboxes[now_key] = client
        self._sandbox_infos[now_key] = info
        self._last_activity[now_key] = time.time()
        self._thread_sandboxes[thread_id] = now_key

    logger.info(f"Discovered existing sandbox {info.sandbox_id} for thread {thread_id} at {info.sandbox_url}")
    return now_key
|
||||
|
||||
def _register_created_sandbox(self, thread_id: str | None, sandbox_id: str, info: SandboxInfo) -> str:
    """Record a freshly created sandbox in the active tracking maps.

    The thread-to-sandbox mapping is only written for a truthy ``thread_id``;
    anonymous sandboxes are tracked by ID alone.

    Returns:
        ``sandbox_id``.
    """
    client = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
    with self._lock:
        self._sandboxes[sandbox_id] = client
        self._sandbox_infos[sandbox_id] = info
        self._last_activity[sandbox_id] = time.time()
        if thread_id:
            self._thread_sandboxes[thread_id] = sandbox_id

    logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
    return sandbox_id
|
||||
|
||||
def _replica_count(self) -> tuple[int, int]:
    """Return ``(replicas, total)``.

    ``replicas`` is the configured ``replicas`` value (``DEFAULT_REPLICAS``
    when unset); ``total`` is the number of active sandboxes plus warm-pool
    entries, read under ``self._lock``.
    """
    configured = self._config.get("replicas", DEFAULT_REPLICAS)
    with self._lock:
        active = len(self._sandboxes)
        warm = len(self._warm_pool)
    return configured, active + warm
|
||||
|
||||
def _log_replicas_soft_cap(self, replicas: int, sandbox_id: str, evicted: str | None) -> None:
    """Log what happened when the replica budget was enforced.

    Args:
        replicas: Configured replica limit.
        sandbox_id: Sandbox about to be created.
        evicted: ID of the warm-pool sandbox that was evicted, or a falsy
            value when nothing could be evicted.
    """
    if not evicted:
        # Nothing in the warm pool could be evicted, so every slot is held by
        # an active sandbox. The limit is a soft cap: an active container is
        # never stopped to make room, the new sandbox is created regardless.
        logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
    else:
        logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
|
||||
|
||||
# ── Core: acquire / get / release / shutdown ─────────────────────────
|
||||
|
||||
def acquire(self, thread_id: str | None = None) -> str:
|
||||
@@ -440,6 +569,23 @@ class AioSandboxProvider(SandboxProvider):
|
||||
else:
|
||||
return self._acquire_internal(thread_id)
|
||||
|
||||
async def acquire_async(self, thread_id: str | None = None) -> str:
    """Acquire a sandbox from async code without blocking the event loop.

    Async mirror of ``acquire()``. With a truthy ``thread_id`` the per-thread
    lock is taken via ``_acquire_thread_lock_async`` and held for the whole
    internal acquisition; without one the internal acquisition runs directly.

    Returns:
        The ID of the acquired sandbox.
    """
    if not thread_id:
        return await self._acquire_internal_async(thread_id)

    per_thread_lock = self._get_thread_lock(thread_id)
    await _acquire_thread_lock_async(per_thread_lock)
    try:
        return await self._acquire_internal_async(thread_id)
    finally:
        per_thread_lock.release()
|
||||
|
||||
def _acquire_internal(self, thread_id: str | None) -> str:
|
||||
"""Internal sandbox acquisition with two-layer consistency.
|
||||
|
||||
@@ -448,33 +594,17 @@ class AioSandboxProvider(SandboxProvider):
|
||||
sandbox_id is deterministic from thread_id so no shared state file
|
||||
is needed — any process can derive the same container name)
|
||||
"""
|
||||
# ── Layer 1: In-process cache (fast path) ──
|
||||
if thread_id:
|
||||
with self._lock:
|
||||
if thread_id in self._thread_sandboxes:
|
||||
existing_id = self._thread_sandboxes[thread_id]
|
||||
if existing_id in self._sandboxes:
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
else:
|
||||
del self._thread_sandboxes[thread_id]
|
||||
cached_id = self._reuse_in_process_sandbox(thread_id)
|
||||
if cached_id is not None:
|
||||
return cached_id
|
||||
|
||||
# Deterministic ID for thread-specific, random for anonymous
|
||||
sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8]
|
||||
sandbox_id = self._sandbox_id_for_thread(thread_id)
|
||||
|
||||
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
|
||||
if thread_id:
|
||||
with self._lock:
|
||||
if sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||
return sandbox_id
|
||||
reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
|
||||
if reclaimed_id is not None:
|
||||
return reclaimed_id
|
||||
|
||||
# ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
|
||||
# Use a file lock so that two processes racing to create the same sandbox
|
||||
@@ -485,6 +615,26 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
return self._create_sandbox(thread_id, sandbox_id)
|
||||
|
||||
async def _acquire_internal_async(self, thread_id: str | None) -> str:
|
||||
"""Async counterpart to ``_acquire_internal``."""
|
||||
cached_id = self._reuse_in_process_sandbox(thread_id)
|
||||
if cached_id is not None:
|
||||
return cached_id
|
||||
|
||||
# Deterministic ID for thread-specific, random for anonymous
|
||||
sandbox_id = self._sandbox_id_for_thread(thread_id)
|
||||
|
||||
# ── Layer 1.5: Warm pool (container still running, no cold-start) ──
|
||||
reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id)
|
||||
if reclaimed_id is not None:
|
||||
return reclaimed_id
|
||||
|
||||
# ── Layer 2: Backend discovery + create (protected by cross-process lock) ──
|
||||
if thread_id:
|
||||
return await self._discover_or_create_with_lock_async(thread_id, sandbox_id)
|
||||
|
||||
return await self._create_sandbox_async(thread_id, sandbox_id)
|
||||
|
||||
def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str:
|
||||
"""Discover an existing sandbox or create a new one under a cross-process file lock.
|
||||
|
||||
@@ -503,40 +653,50 @@ class AioSandboxProvider(SandboxProvider):
|
||||
locked = True
|
||||
# Re-check in-process caches under the file lock in case another
|
||||
# thread in this process won the race while we were waiting.
|
||||
with self._lock:
|
||||
if thread_id in self._thread_sandboxes:
|
||||
existing_id = self._thread_sandboxes[thread_id]
|
||||
if existing_id in self._sandboxes:
|
||||
logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)")
|
||||
self._last_activity[existing_id] = time.time()
|
||||
return existing_id
|
||||
if sandbox_id in self._warm_pool:
|
||||
info, _ = self._warm_pool.pop(sandbox_id)
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)")
|
||||
return sandbox_id
|
||||
cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
|
||||
if cached_id is not None:
|
||||
return cached_id
|
||||
|
||||
# Backend discovery: another process may have created the container.
|
||||
discovered = self._backend.discover(sandbox_id)
|
||||
if discovered is not None:
|
||||
sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url)
|
||||
with self._lock:
|
||||
self._sandboxes[discovered.sandbox_id] = sandbox
|
||||
self._sandbox_infos[discovered.sandbox_id] = discovered
|
||||
self._last_activity[discovered.sandbox_id] = time.time()
|
||||
self._thread_sandboxes[thread_id] = discovered.sandbox_id
|
||||
logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}")
|
||||
return discovered.sandbox_id
|
||||
return self._register_discovered_sandbox(thread_id, discovered)
|
||||
|
||||
return self._create_sandbox(thread_id, sandbox_id)
|
||||
finally:
|
||||
if locked:
|
||||
_unlock_file(lock_file)
|
||||
|
||||
async def _discover_or_create_with_lock_async(self, thread_id: str, sandbox_id: str) -> str:
    """Async counterpart to ``_discover_or_create_with_lock``.

    Opens ``<thread dir>/<sandbox_id>.lock``, takes an exclusive file lock on
    it, re-checks the in-memory caches, asks the backend to discover an
    existing sandbox, and creates one if none is found. All file and backend
    calls are pushed to worker threads with ``asyncio.to_thread``.

    Args:
        thread_id: Thread the sandbox belongs to.
        sandbox_id: Sandbox ID to discover or create.

    Returns:
        The ID of the reused, discovered or newly created sandbox.
    """
    paths = get_paths()
    user_id = get_effective_user_id()
    await asyncio.to_thread(paths.ensure_thread_dirs, thread_id, user_id=user_id)
    lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"

    lock_file = await asyncio.to_thread(_open_lock_file, lock_path)
    locked = False
    try:
        # NOTE(review): if this coroutine is cancelled while waiting here, the
        # worker thread keeps running the lock call (to_thread cannot interrupt
        # it) while the cleanup below closes the file — confirm whether that
        # window needs the same shield/done-callback treatment as the
        # per-thread lock.
        await asyncio.to_thread(_lock_file_exclusive, lock_file)
        locked = True

        # Re-check in-process caches under the file lock in case another
        # thread in this process won the race while we were waiting.
        cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id)
        if cached_id is not None:
            return cached_id

        # Backend discovery is sync because local discovery may inspect
        # Docker and perform a health check; keep it off the event loop.
        discovered = await asyncio.to_thread(self._backend.discover, sandbox_id)
        if discovered is not None:
            return self._register_discovered_sandbox(thread_id, discovered)

        return await self._create_sandbox_async(thread_id, sandbox_id)
    finally:
        try:
            if locked:
                await asyncio.to_thread(_unlock_file, lock_file)
        finally:
            # Close unconditionally: a failing or interrupted unlock must not
            # leak the open lock-file handle.
            await asyncio.to_thread(lock_file.close)
|
||||
|
||||
def _evict_oldest_warm(self) -> str | None:
|
||||
"""Destroy the oldest container in the warm pool to free capacity.
|
||||
|
||||
@@ -574,18 +734,10 @@ class AioSandboxProvider(SandboxProvider):
|
||||
|
||||
# Enforce replicas: only warm-pool containers count toward eviction budget.
|
||||
# Active sandboxes are in use by live threads and must not be forcibly stopped.
|
||||
replicas = self._config.get("replicas", DEFAULT_REPLICAS)
|
||||
with self._lock:
|
||||
total = len(self._sandboxes) + len(self._warm_pool)
|
||||
replicas, total = self._replica_count()
|
||||
if total >= replicas:
|
||||
evicted = self._evict_oldest_warm()
|
||||
if evicted:
|
||||
logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}")
|
||||
else:
|
||||
# All slots are occupied by active sandboxes — proceed anyway and log.
|
||||
# The replicas limit is a soft cap; we never forcibly stop a container
|
||||
# that is actively serving a thread.
|
||||
logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit")
|
||||
self._log_replicas_soft_cap(replicas, sandbox_id, evicted)
|
||||
|
||||
info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None)
|
||||
|
||||
@@ -594,16 +746,27 @@ class AioSandboxProvider(SandboxProvider):
|
||||
self._backend.destroy(info)
|
||||
raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")
|
||||
|
||||
sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url)
|
||||
with self._lock:
|
||||
self._sandboxes[sandbox_id] = sandbox
|
||||
self._sandbox_infos[sandbox_id] = info
|
||||
self._last_activity[sandbox_id] = time.time()
|
||||
if thread_id:
|
||||
self._thread_sandboxes[thread_id] = sandbox_id
|
||||
return self._register_created_sandbox(thread_id, sandbox_id, info)
|
||||
|
||||
logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}")
|
||||
return sandbox_id
|
||||
async def _create_sandbox_async(self, thread_id: str | None, sandbox_id: str) -> str:
    """Async counterpart to ``_create_sandbox``.

    Computes extra mounts, enforces the replica soft cap, creates the sandbox
    through the backend and waits up to 60 seconds for it to become ready.
    Blocking steps run in worker threads; readiness is polled asynchronously.

    Returns:
        ``sandbox_id`` once the sandbox is ready and registered.

    Raises:
        RuntimeError: If the sandbox is not ready within the timeout; the
            backend is asked to destroy it first.
    """
    extra_mounts = await asyncio.to_thread(self._get_extra_mounts, thread_id)

    # Only warm-pool containers may be evicted to make room; active sandboxes
    # are in use by live threads and must not be forcibly stopped.
    replicas, total = self._replica_count()
    if total >= replicas:
        evicted = await asyncio.to_thread(self._evict_oldest_warm)
        self._log_replicas_soft_cap(replicas, sandbox_id, evicted)

    info = await asyncio.to_thread(
        self._backend.create,
        thread_id,
        sandbox_id,
        extra_mounts=extra_mounts or None,
    )

    # Wait for the sandbox to come up without blocking the event loop.
    is_ready = await wait_for_sandbox_ready_async(info.sandbox_url, timeout=60)
    if is_ready:
        return self._register_created_sandbox(thread_id, sandbox_id, info)

    await asyncio.to_thread(self._backend.destroy, info)
    raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}")
|
||||
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
"""Get a sandbox by ID. Updates last activity timestamp.
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from .sandbox_info import SandboxInfo
|
||||
@@ -35,6 +37,34 @@ def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool:
    """Poll a sandbox until it answers, without blocking the event loop.

    Issues ``GET {sandbox_url}/v1/sandbox`` repeatedly until a response with
    status 200 arrives or ``timeout`` seconds (measured on the loop clock)
    have elapsed. The synchronous ``wait_for_sandbox_ready`` remains for
    existing synchronous call sites.

    Args:
        sandbox_url: Base URL of the sandbox.
        timeout: Total time budget in seconds.
        poll_interval: Pause between attempts in seconds.

    Returns:
        ``True`` on the first 200 response, ``False`` once the budget is spent.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    health_url = f"{sandbox_url}/v1/sandbox"

    async with httpx.AsyncClient(timeout=5) as client:
        while (remaining := deadline - loop.time()) > 0:
            try:
                # Per-request timeout never exceeds what is left of the budget.
                response = await client.get(health_url, timeout=min(5.0, remaining))
            except httpx.RequestError:
                # Request failed; treat as "not ready yet" and retry.
                response = None

            if response is not None and response.status_code == 200:
                return True

            # Re-read the clock: the request itself consumed part of the budget.
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            await asyncio.sleep(min(poll_interval, remaining))

    return False
|
||||
|
||||
|
||||
class SandboxBackend(ABC):
|
||||
"""Abstract base for sandbox provisioning backends.
|
||||
|
||||
@@ -44,7 +74,7 @@ class SandboxBackend(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""Create/provision a new sandbox.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -241,7 +241,7 @@ class LocalContainerBackend(SandboxBackend):
|
||||
|
||||
# ── SandboxBackend interface ──────────────────────────────────────────
|
||||
|
||||
def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""Start a new container and return its connection info.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -59,7 +59,7 @@ class RemoteSandboxBackend(SandboxBackend):
|
||||
|
||||
def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
thread_id: str | None,
|
||||
sandbox_id: str,
|
||||
extra_mounts: list[tuple[str, str, bool]] | None = None,
|
||||
) -> SandboxInfo:
|
||||
@@ -132,7 +132,7 @@ class RemoteSandboxBackend(SandboxBackend):
|
||||
logger.warning("Provisioner list_running failed: %s", exc)
|
||||
return []
|
||||
|
||||
def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
def _provisioner_create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo:
|
||||
"""POST /api/sandboxes → create Pod + Service."""
|
||||
try:
|
||||
resp = requests.post(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
|
||||
@@ -48,6 +49,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
logger.info(f"Acquiring sandbox {sandbox_id}")
|
||||
return sandbox_id
|
||||
|
||||
async def _acquire_sandbox_async(self, thread_id: str) -> str:
    """Acquire a sandbox for ``thread_id`` through the provider's async hook.

    Returns:
        The sandbox ID returned by ``provider.acquire_async``.
    """
    sandbox_id = await get_sandbox_provider().acquire_async(thread_id)
    logger.info(f"Acquiring sandbox {sandbox_id}")
    return sandbox_id
|
||||
|
||||
async def _release_sandbox_async(self, sandbox_id: str) -> None:
    """Release ``sandbox_id`` by running the provider's sync ``release`` in a worker thread."""
    provider = get_sandbox_provider()
    await asyncio.to_thread(provider.release, sandbox_id)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
# Skip acquisition if lazy_init is enabled
|
||||
@@ -64,6 +74,23 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
return {"sandbox": {"sandbox_id": sandbox_id}}
|
||||
return super().before_agent(state, runtime)
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
    """Async pre-agent hook: eagerly acquire a sandbox unless lazy init is on.

    Returns a ``{"sandbox": {"sandbox_id": ...}}`` state update when a sandbox
    was acquired here. In every other case (lazy init, sandbox already in
    state, or no ``thread_id`` in the runtime context) the call is delegated
    to the parent class.
    """
    # Eager initialization only applies when lazy init is off and the state
    # does not already carry a sandbox.
    if not self._lazy_init and state.get("sandbox") is None:
        thread_id = (runtime.context or {}).get("thread_id")
        if thread_id is not None:
            # Use the async provider hook so sandbox startup/polling does not
            # run on the event loop.
            sandbox_id = await self._acquire_sandbox_async(thread_id)
            logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}")
            return {"sandbox": {"sandbox_id": sandbox_id}}

    return await super().abefore_agent(state, runtime)
|
||||
|
||||
@override
|
||||
def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
sandbox = state.get("sandbox")
|
||||
@@ -81,3 +108,21 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]):
|
||||
|
||||
# No sandbox to release
|
||||
return super().after_agent(state, runtime)
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None:
    """Async post-agent hook: release the sandbox recorded for this run.

    The sandbox ID is taken from ``state["sandbox"]`` when present, otherwise
    from ``runtime.context["sandbox_id"]``. When neither is set the call is
    delegated to the parent class.

    Returns:
        ``None`` after a release, else whatever the parent hook returns.
    """
    sandbox = state.get("sandbox")
    if sandbox is not None:
        sandbox_id = sandbox["sandbox_id"]
        logger.info(f"Releasing sandbox {sandbox_id}")
        await self._release_sandbox_async(sandbox_id)
        return None

    # Fall back to the ID stored in the runtime context, if any.
    sandbox_id = (runtime.context or {}).get("sandbox_id")
    if sandbox_id is not None:
        logger.info(f"Releasing sandbox {sandbox_id} from context")
        await self._release_sandbox_async(sandbox_id)
        return None

    # No sandbox to release
    return await super().aafter_agent(state, runtime)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
@@ -19,6 +20,16 @@ class SandboxProvider(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def acquire_async(self, thread_id: str | None = None) -> str:
|
||||
"""Acquire a sandbox without blocking the event loop.
|
||||
|
||||
Most sandbox providers expose a synchronous lifecycle API because local
|
||||
Docker/provisioner operations are blocking. Async runtimes should call
|
||||
this method so those blocking operations run in a worker thread instead
|
||||
of stalling the event loop.
|
||||
"""
|
||||
return await asyncio.to_thread(self.acquire, thread_id)
|
||||
|
||||
@abstractmethod
|
||||
def get(self, sandbox_id: str) -> Sandbox | None:
|
||||
"""Get a sandbox environment by ID.
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import posixpath
|
||||
import re
|
||||
import shlex
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.tools import tool
|
||||
@@ -1111,6 +1113,68 @@ def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox:
|
||||
return sandbox
|
||||
|
||||
|
||||
async def ensure_sandbox_initialized_async(runtime: Runtime | None = None) -> Sandbox:
    """Async counterpart to ``ensure_sandbox_initialized`` for tool runtimes.

    Returns the sandbox already referenced by ``runtime.state`` when the
    provider still knows it; otherwise acquires one through
    ``provider.acquire_async`` and stores its ID in ``runtime.state`` and,
    when a context exists, in ``runtime.context``.

    Raises:
        SandboxRuntimeError: If the runtime, its state, or a thread ID is
            missing.
        SandboxNotFoundError: If the provider cannot return the sandbox it
            just acquired.
    """
    if runtime is None:
        raise SandboxRuntimeError("Tool runtime not available")
    if runtime.state is None:
        raise SandboxRuntimeError("Tool runtime state not available")

    # Fast path: the state already names a sandbox the provider can resolve.
    sandbox_state = runtime.state.get("sandbox")
    known_id = sandbox_state.get("sandbox_id") if sandbox_state is not None else None
    if known_id is not None:
        existing = get_sandbox_provider().get(known_id)
        if existing is not None:
            if runtime.context is not None:
                runtime.context["sandbox_id"] = known_id
            return existing

    # Resolve the thread ID from the context first, then from the config.
    thread_id = runtime.context.get("thread_id") if runtime.context else None
    if thread_id is None and runtime.config:
        thread_id = runtime.config.get("configurable", {}).get("thread_id")
    if thread_id is None:
        raise SandboxRuntimeError("Thread ID not available in runtime context")

    provider = get_sandbox_provider()
    sandbox_id = await provider.acquire_async(thread_id)
    runtime.state["sandbox"] = {"sandbox_id": sandbox_id}

    sandbox = provider.get(sandbox_id)
    if sandbox is None:
        raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)

    if runtime.context is not None:
        runtime.context["sandbox_id"] = sandbox_id
    return sandbox
|
||||
|
||||
|
||||
async def _run_sync_tool_after_async_sandbox_init(
    func: Callable[..., str] | None,
    runtime: Runtime,
    *args: object,
) -> str:
    """Initialize the sandbox via the async provider, then run a sync tool body off-thread.

    Args:
        func: Synchronous tool implementation, called as ``func(runtime, *args)``.
            ``None`` means the tool has no sync body to run.
        runtime: Tool runtime passed to sandbox initialization and to ``func``.
        *args: Positional arguments forwarded to ``func`` after ``runtime``.

    Returns:
        The string returned by ``func``, or an ``"Error: ..."`` message when
        the implementation is missing or sandbox initialization fails.
    """
    # Validate before acquiring: sandbox acquisition is the expensive step and
    # also mutates runtime state, so do not pay for it when there is nothing to run.
    if func is None:
        return "Error: Tool implementation not available"

    try:
        await ensure_sandbox_initialized_async(runtime)
    except SandboxError as e:
        return f"Error: {e}"
    except Exception as e:
        # Errors are reported as tool output rather than raised; unexpected
        # ones go through _sanitize_error first.
        return f"Error: Unexpected error initializing sandbox: {_sanitize_error(e, runtime)}"

    # The tool body is synchronous; run it in a worker thread via asyncio.to_thread.
    return await asyncio.to_thread(func, runtime, *args)
|
||||
|
||||
|
||||
def ensure_thread_directories_exist(runtime: Runtime | None) -> None:
|
||||
"""Ensure thread data directories (workspace, uploads, outputs) exist.
|
||||
|
||||
@@ -1273,6 +1337,13 @@ def bash_tool(runtime: Runtime, description: str, command: str) -> str:
|
||||
return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _bash_tool_async(runtime: Runtime, description: str, command: str) -> str:
    """Async entry point for the ``bash`` tool.

    Hands the tool's synchronous implementation (``bash_tool.func``) and the
    unchanged arguments to ``_run_sync_tool_after_async_sandbox_init``.
    """
    sync_impl = bash_tool.func
    return await _run_sync_tool_after_async_sandbox_init(sync_impl, runtime, description, command)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
bash_tool.coroutine = _bash_tool_async
|
||||
|
||||
|
||||
@tool("ls", parse_docstring=True)
|
||||
def ls_tool(runtime: Runtime, description: str, path: str) -> str:
|
||||
"""List the contents of a directory up to 2 levels deep in tree format.
|
||||
@@ -1320,6 +1391,13 @@ def ls_tool(runtime: Runtime, description: str, path: str) -> str:
|
||||
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _ls_tool_async(runtime: Runtime, description: str, path: str) -> str:
    """Async entry point for the ``ls`` tool.

    Hands the tool's synchronous implementation (``ls_tool.func``) and the
    unchanged arguments to ``_run_sync_tool_after_async_sandbox_init``.
    """
    sync_impl = ls_tool.func
    return await _run_sync_tool_after_async_sandbox_init(sync_impl, runtime, description, path)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
ls_tool.coroutine = _ls_tool_async
|
||||
|
||||
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: Runtime,
|
||||
@@ -1370,6 +1448,28 @@ def glob_tool(
|
||||
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _glob_tool_async(
    runtime: Runtime,
    description: str,
    pattern: str,
    path: str,
    include_dirs: bool = False,
    max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
) -> str:
    """Async entry point for the ``glob`` tool.

    Forwards every argument positionally, in signature order, together with
    the tool's synchronous implementation (``glob_tool.func``) to
    ``_run_sync_tool_after_async_sandbox_init``.
    """
    forwarded = (description, pattern, path, include_dirs, max_results)
    return await _run_sync_tool_after_async_sandbox_init(glob_tool.func, runtime, *forwarded)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
glob_tool.coroutine = _glob_tool_async
|
||||
|
||||
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: Runtime,
|
||||
@@ -1440,6 +1540,32 @@ def grep_tool(
|
||||
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _grep_tool_async(
    runtime: Runtime,
    description: str,
    pattern: str,
    path: str,
    glob: str | None = None,
    literal: bool = False,
    case_sensitive: bool = False,
    max_results: int = _DEFAULT_GREP_MAX_RESULTS,
) -> str:
    """Async entry point for the ``grep`` tool.

    Forwards every argument positionally, in signature order, together with
    the tool's synchronous implementation (``grep_tool.func``) to
    ``_run_sync_tool_after_async_sandbox_init``.
    """
    forwarded = (description, pattern, path, glob, literal, case_sensitive, max_results)
    return await _run_sync_tool_after_async_sandbox_init(grep_tool.func, runtime, *forwarded)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
grep_tool.coroutine = _grep_tool_async
|
||||
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: Runtime,
|
||||
@@ -1495,6 +1621,19 @@ def read_file_tool(
|
||||
return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _read_file_tool_async(
    runtime: Runtime,
    description: str,
    path: str,
    start_line: int | None = None,
    end_line: int | None = None,
) -> str:
    """Async entry point for the ``read_file`` tool.

    Forwards every argument positionally, in signature order, together with
    the tool's synchronous implementation (``read_file_tool.func``) to
    ``_run_sync_tool_after_async_sandbox_init``.
    """
    forwarded = (description, path, start_line, end_line)
    return await _run_sync_tool_after_async_sandbox_init(read_file_tool.func, runtime, *forwarded)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
read_file_tool.coroutine = _read_file_tool_async
|
||||
|
||||
|
||||
@tool("write_file", parse_docstring=True)
|
||||
def write_file_tool(
|
||||
runtime: Runtime,
|
||||
@@ -1536,6 +1675,19 @@ def write_file_tool(
|
||||
return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _write_file_tool_async(
    runtime: Runtime,
    description: str,
    path: str,
    content: str,
    append: bool = False,
) -> str:
    """Async entry point for the ``write_file`` tool.

    Forwards every argument positionally, in signature order, together with
    the tool's synchronous implementation (``write_file_tool.func``) to
    ``_run_sync_tool_after_async_sandbox_init``.
    """
    forwarded = (description, path, content, append)
    return await _run_sync_tool_after_async_sandbox_init(write_file_tool.func, runtime, *forwarded)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
write_file_tool.coroutine = _write_file_tool_async
|
||||
|
||||
|
||||
@tool("str_replace", parse_docstring=True)
|
||||
def str_replace_tool(
|
||||
runtime: Runtime,
|
||||
@@ -1585,3 +1737,25 @@ def str_replace_tool(
|
||||
return f"Error: Permission denied accessing file: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
async def _str_replace_tool_async(
    runtime: Runtime,
    description: str,
    path: str,
    old_str: str,
    new_str: str,
    replace_all: bool = False,
) -> str:
    """Async entry point for the ``str_replace`` tool.

    Forwards every argument positionally, in signature order, together with
    the tool's synchronous implementation (``str_replace_tool.func``) to
    ``_run_sync_tool_after_async_sandbox_init``.
    """
    forwarded = (description, path, old_str, new_str, replace_all)
    return await _run_sync_tool_after_async_sandbox_init(str_replace_tool.func, runtime, *forwarded)


# Attach the wrapper as the tool's ``coroutine`` (presumably what the tool
# framework awaits on async invocation — confirm against the tool class).
str_replace_tool.coroutine = _str_replace_tool_async
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for AioSandboxProvider mount helpers."""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -140,6 +141,182 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc
|
||||
assert unlock_calls == []
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_acquire_async_uses_async_readiness_polling(monkeypatch):
    """AioSandboxProvider async creation must not use sync readiness polling."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")

    provider = _make_provider(None)
    provider._config = {"replicas": 3}
    # Each bookkeeping map starts as its own empty dict.
    for attr in ("_thread_locks", "_warm_pool", "_sandbox_infos", "_thread_sandboxes", "_last_activity"):
        setattr(provider, attr, {})
    provider._lock = aio_mod.threading.Lock()

    created_info = aio_mod.SandboxInfo(sandbox_id="sandbox-async", sandbox_url="http://sandbox")
    provider._backend = SimpleNamespace(
        create=MagicMock(return_value=created_info),
        destroy=MagicMock(),
        discover=MagicMock(return_value=None),
    )

    readiness_probes: list[tuple[str, int]] = []

    async def fake_wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool:
        # Record the probe and report the sandbox as ready immediately.
        readiness_probes.append((sandbox_url, timeout))
        return True

    def forbid_sync_readiness(*_args, **_kwargs):
        raise AssertionError("sync readiness should not be used")

    monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready_async", fake_wait_for_sandbox_ready_async)
    monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", forbid_sync_readiness)

    sandbox_id = await provider._create_sandbox_async("thread-async", "sandbox-async")

    assert sandbox_id == "sandbox-async"
    assert readiness_probes == [("http://sandbox", 60)]
    assert provider._backend.destroy.call_count == 0
    assert provider._thread_sandboxes["thread-async"] == "sandbox-async"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_discover_or_create_with_lock_async_offloads_lock_file_open_and_close(tmp_path, monkeypatch):
    """Async lock path must not open or close lock files on the event loop."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
    provider_cls = aio_mod.AioSandboxProvider

    provider = _make_provider(tmp_path)
    # Bind the real coroutine method onto the test provider instance.
    provider._discover_or_create_with_lock_async = provider_cls._discover_or_create_with_lock_async.__get__(
        provider,
        provider_cls,
    )
    for attr in ("_thread_locks", "_warm_pool", "_sandbox_infos"):
        setattr(provider, attr, {})
    # The thread already maps to a known sandbox.
    provider._thread_sandboxes = {"thread-async-lock": "sandbox-async-lock"}
    provider._sandboxes = {"sandbox-async-lock": aio_mod.AioSandbox(id="sandbox-async-lock", base_url="http://sandbox")}
    provider._last_activity = {}
    provider._lock = aio_mod.threading.Lock()
    provider._backend = SimpleNamespace(discover=MagicMock(return_value=None))

    monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path))

    offloaded: list[object] = []

    async def fake_to_thread(func, /, *args, **kwargs):
        # Record what was offloaded, then run it inline.
        offloaded.append(func)
        return func(*args, **kwargs)

    monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread)

    sandbox_id = await provider._discover_or_create_with_lock_async("thread-async-lock", "sandbox-async-lock")

    assert sandbox_id == "sandbox-async-lock"
    assert aio_mod._open_lock_file in offloaded
    offloaded_names = [getattr(func, "__name__", "") for func in offloaded]
    assert "close" in offloaded_names
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_acquire_thread_lock_async_uses_dedicated_executor(monkeypatch):
    """Per-thread lock waits should not consume the default asyncio.to_thread pool."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
    lock = aio_mod.threading.Lock()

    async def fail_to_thread(*_args, **_kwargs):
        raise AssertionError("thread-lock acquisition must not use asyncio.to_thread")

    # Any use of asyncio.to_thread during acquisition fails the test.
    monkeypatch.setattr(aio_mod.asyncio, "to_thread", fail_to_thread)

    await aio_mod._acquire_thread_lock_async(lock)
    try:
        # The helper must have left the lock held, so a second grab fails.
        reacquired = lock.acquire(blocking=False)
        assert not reacquired
    finally:
        lock.release()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_acquire_async_cancellation_does_not_leak_thread_lock(tmp_path):
    """Cancelled async lock waiters must not leave the per-thread lock held."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
    provider = _make_provider(tmp_path)
    for attr in ("_thread_locks", "_warm_pool", "_sandbox_infos", "_thread_sandboxes", "_last_activity"):
        setattr(provider, attr, {})
    provider._lock = aio_mod.threading.Lock()

    thread_id = "thread-cancel-lock"
    thread_lock = provider._get_thread_lock(thread_id)
    # Hold the lock so acquire_async has to wait, then cancel the waiter.
    thread_lock.acquire()

    waiter = asyncio.create_task(provider.acquire_async(thread_id))
    await asyncio.sleep(0.05)
    waiter.cancel()

    try:
        await waiter
    except asyncio.CancelledError:
        pass

    thread_lock.release()

    # Poll for up to one second: the lock must become free again.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + 1
    while loop.time() < deadline:
        if thread_lock.acquire(blocking=False):
            thread_lock.release()
            return
        await asyncio.sleep(0.01)

    pytest.fail("provider thread lock was leaked after cancelling acquire_async")
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path, monkeypatch):
    """A cancelled waiter must not prevent the next live waiter from acquiring."""
    aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider")
    provider = _make_provider(tmp_path)
    for attr in ("_thread_locks", "_warm_pool", "_sandbox_infos", "_thread_sandboxes", "_last_activity"):
        setattr(provider, attr, {})
    provider._lock = aio_mod.threading.Lock()

    async def fake_acquire_internal_async(thread_id: str | None) -> str:
        # Stand-in for the real acquisition: yield once, then succeed.
        assert thread_id == "thread-successor-lock"
        await asyncio.sleep(0)
        return "sandbox-successor"

    monkeypatch.setattr(provider, "_acquire_internal_async", fake_acquire_internal_async)

    thread_id = "thread-successor-lock"
    thread_lock = provider._get_thread_lock(thread_id)
    thread_lock.acquire()

    # First waiter queues behind the held lock and is then cancelled.
    cancelled_waiter = asyncio.create_task(provider.acquire_async(thread_id))
    await asyncio.sleep(0.05)
    cancelled_waiter.cancel()
    try:
        await cancelled_waiter
    except asyncio.CancelledError:
        pass

    # Second waiter must get through once the lock is released.
    live_waiter = asyncio.create_task(provider.acquire_async(thread_id))
    thread_lock.release()

    successor_id = await asyncio.wait_for(live_waiter, timeout=1)
    assert successor_id == "sandbox-successor"

    # Poll for up to one second: the lock must be free after the successor finished.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + 1
    while loop.time() < deadline:
        if thread_lock.acquire(blocking=False):
            thread_lock.release()
            return
        await asyncio.sleep(0.01)

    pytest.fail("provider thread lock was not released after successor acquire_async")
|
||||
|
||||
|
||||
def test_remote_backend_create_forwards_effective_user_id(monkeypatch):
|
||||
"""Provisioner mode must receive user_id so PVC subPath matches user isolation."""
|
||||
remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend")
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.community.aio_sandbox import backend as readiness
|
||||
|
||||
|
||||
class _FakeAsyncClient:
    """Scripted async HTTP client used in place of ``httpx.AsyncClient`` in these tests.

    ``get`` replays ``responses`` in order: exception instances are raised,
    anything else is returned. Each requested URL is appended to ``calls``
    and, when ``request_timeouts`` is given, each per-request timeout is
    appended to it.
    """

    def __init__(self, *, responses: list[object], calls: list[str], timeout: float, request_timeouts: list[float] | None = None) -> None:
        # The lists are shared with the test, which inspects them afterwards.
        self._timeout = timeout
        self._request_timeouts = request_timeouts
        self._calls = calls
        self._responses = responses

    async def __aenter__(self) -> "_FakeAsyncClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Nothing to clean up; a falsy return lets exceptions propagate.
        return None

    async def get(self, url: str, *, timeout: float):
        self._calls.append(url)
        if self._request_timeouts is not None:
            self._request_timeouts.append(timeout)
        scripted = self._responses.pop(0)
        if not isinstance(scripted, BaseException):
            return scripted
        raise scripted
|
||||
|
||||
|
||||
class _FakeLoop:
    """Fake event loop whose ``time()`` replays a fixed list of clock readings."""

    def __init__(self, times: list[float]) -> None:
        self._times = times
        self._index = 0

    def time(self) -> float:
        """Return the next scripted reading; raises ``IndexError`` once exhausted."""
        position = self._index
        reading = self._times[position]
        self._index = position + 1
        return reading
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_uses_nonblocking_polling(monkeypatch: pytest.MonkeyPatch) -> None:
    """The async poller must go through httpx and asyncio.sleep, never requests.get or time.sleep."""
    calls: list[str] = []
    sleeps: list[float] = []

    def client_factory(*, timeout: float):
        # First probe answers 503, second answers 200.
        return _FakeAsyncClient(
            responses=[SimpleNamespace(status_code=503), SimpleNamespace(status_code=200)],
            calls=calls,
            timeout=timeout,
        )

    async def record_sleep(delay: float) -> None:
        sleeps.append(delay)

    def forbid_requests_get(*args, **kwargs):
        raise AssertionError("requests.get should not be used")

    def forbid_time_sleep(*_args, **_kwargs):
        raise AssertionError("time.sleep should not be used")

    monkeypatch.setattr(readiness.httpx, "AsyncClient", client_factory)
    monkeypatch.setattr(readiness.asyncio, "sleep", record_sleep)
    monkeypatch.setattr(readiness.requests, "get", forbid_requests_get)
    monkeypatch.setattr(readiness.time, "sleep", forbid_time_sleep)

    ready = await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.05)

    assert ready is True
    assert calls == ["http://sandbox/v1/sandbox", "http://sandbox/v1/sandbox"]
    assert sleeps == [0.05]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_retries_request_errors(monkeypatch: pytest.MonkeyPatch) -> None:
    """A connection error on the first probe is retried; the second probe succeeds."""
    calls: list[str] = []
    sleeps: list[float] = []

    def client_factory(*, timeout: float):
        # First probe raises ConnectError, second answers 200.
        return _FakeAsyncClient(
            responses=[readiness.httpx.ConnectError("not ready"), SimpleNamespace(status_code=200)],
            calls=calls,
            timeout=timeout,
        )

    async def record_sleep(delay: float) -> None:
        sleeps.append(delay)

    monkeypatch.setattr(readiness.httpx, "AsyncClient", client_factory)
    monkeypatch.setattr(readiness.asyncio, "sleep", record_sleep)

    ready = await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.01)

    assert ready is True
    assert len(calls) == 2
    assert sleeps == [0.01]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_wait_for_sandbox_ready_async_clamps_request_and_sleep_to_deadline(monkeypatch: pytest.MonkeyPatch) -> None:
    """With a scripted clock, the request timeout and the sleep are both cut to the remaining budget."""
    calls: list[str] = []
    request_timeouts: list[float] = []
    sleeps: list[float] = []

    def client_factory(*, timeout: float):
        # The only scripted probe answers 503, so readiness is never reached.
        return _FakeAsyncClient(
            responses=[SimpleNamespace(status_code=503)],
            calls=calls,
            timeout=timeout,
            request_timeouts=request_timeouts,
        )

    async def record_sleep(delay: float) -> None:
        sleeps.append(delay)

    def scripted_running_loop():
        # A new fake loop per call, exactly as the original lambda did.
        return _FakeLoop([100.0, 100.5, 101.75, 102.0])

    monkeypatch.setattr(readiness.httpx, "AsyncClient", client_factory)
    monkeypatch.setattr(readiness.asyncio, "sleep", record_sleep)
    monkeypatch.setattr(readiness.asyncio, "get_running_loop", scripted_running_loop)

    ready = await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=2, poll_interval=1.0)

    assert ready is False
    assert calls == ["http://sandbox/v1/sandbox"]
    assert request_timeouts == [1.5]
    assert sleeps == [0.25]
|
||||
@@ -159,6 +159,26 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch):
|
||||
assert info.sandbox_url == "http://k3s:31001"
|
||||
|
||||
|
||||
def test_provisioner_create_accepts_anonymous_thread_id(monkeypatch):
    """``create(None, ...)`` must post ``thread_id: None`` to the provisioner and return its sandbox info."""
    backend = RemoteSandboxBackend("http://provisioner:8002")
    expected_body = {
        "sandbox_id": "anon123",
        "thread_id": None,
        "user_id": "test-user-autouse",
    }

    def mock_post(url: str, json: dict, timeout: int):
        # Verify the outgoing request before returning the canned response.
        assert url == "http://provisioner:8002/api/sandboxes"
        assert json == expected_body
        assert timeout == 30
        return _StubResponse(payload={"sandbox_id": "anon123", "sandbox_url": "http://k3s:31002"})

    monkeypatch.setattr(requests, "post", mock_post)

    info = backend.create(None, "anon123")

    assert info.sandbox_id == "anon123"
    assert info.sandbox_url == "http://k3s:31002"
|
||||
|
||||
|
||||
def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch):
|
||||
backend = RemoteSandboxBackend("http://provisioner:8002")
|
||||
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.tools import ToolRuntime
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.sandbox.middleware import SandboxMiddleware
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import SandboxProvider, reset_sandbox_provider, set_sandbox_provider
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
from deerflow.sandbox.tools import ls_tool
|
||||
|
||||
|
||||
class _SyncProvider(SandboxProvider):
    """Provider that only implements the synchronous hooks and records acquire calls."""

    def __init__(self) -> None:
        # Thread ids passed to acquire(), in call order.
        self.thread_ids: list[str | None] = []

    def acquire(self, thread_id: str | None = None) -> str:
        self.thread_ids.append(thread_id)
        return "sync-sandbox"

    def get(self, sandbox_id: str) -> Sandbox | None:
        # This provider never resolves a sandbox.
        return None

    def release(self, sandbox_id: str) -> None:
        return None
|
||||
|
||||
|
||||
class _SandboxStub(Sandbox):
    """Sandbox double whose every operation returns a fixed value."""

    def execute_command(self, command: str) -> str:
        return "OK"

    def read_file(self, path: str) -> str:
        return "content"

    def download_file(self, path: str) -> bytes:
        return b"content"

    def list_dir(self, path: str, max_depth: int = 2) -> list[str]:
        # A single fixed entry, regardless of the requested path.
        return ["/mnt/user-data/workspace/file.txt"]

    def write_file(self, path: str, content: str, append: bool = False) -> None:
        return None

    def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
        return ([], False)

    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> tuple[list[GrepMatch], bool]:
        return ([], False)

    def update_file(self, path: str, content: bytes) -> None:
        return None
|
||||
|
||||
|
||||
class _AsyncOnlyProvider(SandboxProvider):
    """Provider whose sync ``acquire`` fails the test; only ``acquire_async`` succeeds."""

    def __init__(self) -> None:
        # Thread ids passed to acquire_async(), in call order.
        self.thread_ids: list[str | None] = []
        # Sandbox ids passed to release(), in call order.
        self.released_ids: list[str] = []
        self.sandbox = _SandboxStub("async-sandbox")

    def acquire(self, thread_id: str | None = None) -> str:
        raise AssertionError("async middleware should not call sync acquire")

    async def acquire_async(self, thread_id: str | None = None) -> str:
        self.thread_ids.append(thread_id)
        return "async-sandbox"

    def get(self, sandbox_id: str) -> Sandbox | None:
        # Only the id handed out by acquire_async resolves to the stub.
        if sandbox_id != "async-sandbox":
            return None
        return self.sandbox

    def release(self, sandbox_id: str) -> None:
        self.released_ids.append(sandbox_id)
        return None
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_provider_default_acquire_async_offloads_sync_acquire(monkeypatch: pytest.MonkeyPatch) -> None:
    """The inherited ``acquire_async`` must route the sync ``acquire`` through ``asyncio.to_thread``."""
    provider = _SyncProvider()
    offloaded: list[tuple[object, tuple[object, ...]]] = []

    async def fake_to_thread(func, /, *args):
        # Record the offloaded call, then run it inline.
        offloaded.append((func, args))
        return func(*args)

    monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)

    sandbox_id = await provider.acquire_async("thread-1")

    assert sandbox_id == "sync-sandbox"
    assert provider.thread_ids == ["thread-1"]
    assert offloaded == [(provider.acquire, ("thread-1",))]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_abefore_agent_uses_async_provider_acquire() -> None:
    """With eager init, ``abefore_agent`` must acquire through ``acquire_async``."""
    provider = _AsyncOnlyProvider()
    set_sandbox_provider(provider)
    try:
        middleware = SandboxMiddleware(lazy_init=False)
        runtime = Runtime(context={"thread_id": "thread-2"})
        result = await middleware.abefore_agent({}, runtime)
    finally:
        # Always restore the global provider, even if the call raised.
        reset_sandbox_provider()

    assert result == {"sandbox": {"sandbox_id": "async-sandbox"}}
    assert provider.thread_ids == ["thread-2"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("middleware", "state", "runtime"),
    [
        (SandboxMiddleware(lazy_init=True), {}, Runtime(context={"thread_id": "thread-lazy"})),
        (SandboxMiddleware(lazy_init=False), {}, Runtime(context={})),
        (SandboxMiddleware(lazy_init=False), {"sandbox": {"sandbox_id": "existing"}}, Runtime(context={"thread_id": "thread-existing"})),
    ],
)
async def test_abefore_agent_delegates_to_super_when_not_acquiring(
    monkeypatch: pytest.MonkeyPatch,
    middleware: SandboxMiddleware,
    state: dict,
    runtime: Runtime,
) -> None:
    """``abefore_agent`` must return the base-class result in each parametrized case.

    Cases: lazy init enabled; eager init with no thread id in the context;
    eager init with a sandbox already present in state.
    """
    delegated: list[tuple[dict, Runtime]] = []

    async def fake_super_abefore_agent(self, state_arg, runtime_arg):
        # Stand-in for AgentMiddleware.abefore_agent that records its arguments.
        delegated.append((state_arg, runtime_arg))
        return {"delegated": True}

    monkeypatch.setattr(AgentMiddleware, "abefore_agent", fake_super_abefore_agent)

    outcome = await middleware.abefore_agent(state, runtime)

    assert outcome == {"delegated": True}
    assert delegated == [(state, runtime)]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_default_lazy_tool_acquisition_uses_async_provider() -> None:
    """Invoking ``ls`` asynchronously with no sandbox in state must acquire via ``acquire_async``."""
    provider = _AsyncOnlyProvider()
    set_sandbox_provider(provider)
    try:
        runtime = ToolRuntime(
            state={},
            context={"thread_id": "thread-lazy"},
            config={"configurable": {}},
            stream_writer=lambda _: None,
            tools=[],
            tool_call_id="call-1",
            store=None,
        )
        tool_input = {"runtime": runtime, "description": "list workspace", "path": "/mnt/user-data/workspace"}
        result = await ls_tool.ainvoke(tool_input)
    finally:
        # Always restore the global provider, even if the call raised.
        reset_sandbox_provider()

    assert result == "/mnt/user-data/workspace/file.txt"
    assert provider.thread_ids == ["thread-lazy"]
    # The acquired id must be recorded in both the state and the context.
    assert runtime.state["sandbox"] == {"sandbox_id": "async-sandbox"}
    assert runtime.context["sandbox_id"] == "async-sandbox"
|
||||
|
||||
|
||||
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("state", "runtime", "expected_sandbox_id"),
    [
        ({"sandbox": {"sandbox_id": "state-sandbox"}}, Runtime(context={}), "state-sandbox"),
        ({}, Runtime(context={"sandbox_id": "context-sandbox"}), "context-sandbox"),
    ],
)
async def test_aafter_agent_releases_sandbox_off_thread(
    monkeypatch: pytest.MonkeyPatch,
    state: dict,
    runtime: Runtime,
    expected_sandbox_id: str,
) -> None:
    """``aafter_agent`` must release the sandbox through ``asyncio.to_thread``.

    The sandbox id comes from the state in one case and from the runtime
    context in the other.
    """
    provider = _AsyncOnlyProvider()
    offloaded: list[tuple[object, tuple[object, ...]]] = []

    async def fake_to_thread(func, /, *args):
        # Record the offloaded call, then run it inline.
        offloaded.append((func, args))
        return func(*args)

    monkeypatch.setattr(asyncio, "to_thread", fake_to_thread)
    set_sandbox_provider(provider)
    try:
        outcome = await SandboxMiddleware().aafter_agent(state, runtime)
    finally:
        # Always restore the global provider, even if the call raised.
        reset_sandbox_provider()

    assert outcome is None
    assert provider.released_ids == [expected_sandbox_id]
    assert offloaded == [(provider.release, (expected_sandbox_id,))]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no sandbox id in state or context, ``aafter_agent`` must return the base-class result."""
    delegated: list[tuple[dict, Runtime]] = []

    async def fake_super_aafter_agent(self, state_arg, runtime_arg):
        # Stand-in for AgentMiddleware.aafter_agent that records its arguments.
        delegated.append((state_arg, runtime_arg))
        return {"delegated": True}

    monkeypatch.setattr(AgentMiddleware, "aafter_agent", fake_super_aafter_agent)

    state = {}
    runtime = Runtime(context={})
    outcome = await SandboxMiddleware().aafter_agent(state, runtime)

    assert outcome == {"delegated": True}
    assert delegated == [(state, runtime)]
|
||||
Reference in New Issue
Block a user