From 8b697245ebe835bedc4e386dbfb15c422e2f6e11 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Thu, 21 May 2026 14:44:34 +0800 Subject: [PATCH] fix(sandbox): avoid blocking sandbox readiness polling (#2822) * fix(sandbox): offload async sandbox acquisition Run blocking sandbox provider acquisition through the async provider hook so eager sandbox setup does not stall the event loop. * fix(sandbox): add async readiness polling Introduce an async sandbox readiness poller using httpx and asyncio.sleep while preserving the existing synchronous API. * test(sandbox): cover async readiness polling Lock in non-blocking readiness behavior so the async helper does not regress to requests.get or time.sleep. * fix(sandbox): allow anonymous backend creation * fix(sandbox): use async readiness in provider acquisition * fix(sandbox): use async acquisition for lazy tools * test(sandbox): cover anonymous remote creation * fix(sandbox): clamp async readiness timeout budget * fix(sandbox): offload async lock file handling * fix(sandbox): delegate async middleware fallthrough * docs(sandbox): document async acquisition path * fix(sandbox): offload async sandbox release * docs(sandbox): mention async release hook * fix(sandbox): address async lock review Reduce duplicate sync/async sandbox acquisition state handling and move async thread-lock waits onto a dedicated executor with cancellation-safe cleanup. * chore: retrigger ci Retrigger GitHub Actions after upstream main fixed the stale PR merge lint failure. * test(sandbox): sync backend unit fixtures --------- Co-authored-by: Willem Jiang --- backend/CLAUDE.md | 2 +- backend/README.md | 2 +- .../aio_sandbox/aio_sandbox_provider.py | 297 ++++++++++++++---- .../deerflow/community/aio_sandbox/backend.py | 32 +- .../community/aio_sandbox/local_backend.py | 2 +- .../community/aio_sandbox/remote_backend.py | 4 +- .../harness/deerflow/sandbox/middleware.py | 45 +++ .../deerflow/sandbox/sandbox_provider.py | 11 + .../harness/deerflow/sandbox/tools.py | 174 ++++++++++ backend/tests/test_aio_sandbox_provider.py | 177 +++++++++++ backend/tests/test_aio_sandbox_readiness.py | 119 +++++++ backend/tests/test_remote_sandbox_backend.py | 20 ++ backend/tests/test_sandbox_middleware.py | 225 +++++++++++++ 13 files changed, 1037 insertions(+), 73 deletions(-) create mode 100644 backend/tests/test_aio_sandbox_readiness.py create mode 100644 backend/tests/test_sandbox_middleware.py diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index b951f919c..886b82dcb 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -236,7 +236,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti ### Sandbox System (`packages/harness/deerflow/sandbox/`) **Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir` -**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle +**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop. **Implementations**: - `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`. - `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation diff --git a/backend/README.md b/backend/README.md index 8c61e2db2..0ee0d454b 100644 --- a/backend/README.md +++ b/backend/README.md @@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern: Per-thread isolated execution with virtual path translation: - **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir` -- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/) +- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. - **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories - **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory - **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py index 292a43758..4d7e16cab 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py @@ -10,6 +10,7 @@ The provider itself handles: - Mount computation (thread-specific, skills) """ +import asyncio import atexit import hashlib import logging @@ -18,6 +19,7 @@ import signal import threading import time import uuid +from concurrent.futures import ThreadPoolExecutor try: import fcntl @@ -32,7 +34,7 @@ from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox_provider import SandboxProvider from .aio_sandbox import AioSandbox -from .backend import SandboxBackend, wait_for_sandbox_ready +from .backend import SandboxBackend, wait_for_sandbox_ready, wait_for_sandbox_ready_async from .local_backend import LocalContainerBackend from .remote_backend import RemoteSandboxBackend from .sandbox_info import SandboxInfo @@ -46,6 +48,9 @@ DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox" DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds +THREAD_LOCK_EXECUTOR_WORKERS = min(32, (os.cpu_count() or 1) + 4) +_THREAD_LOCK_EXECUTOR = ThreadPoolExecutor(max_workers=THREAD_LOCK_EXECUTOR_WORKERS, thread_name_prefix="sandbox-lock-wait") +atexit.register(_THREAD_LOCK_EXECUTOR.shutdown, wait=False, cancel_futures=True) def _lock_file_exclusive(lock_file) -> None: @@ -66,6 +71,40 @@ def _unlock_file(lock_file) -> None: msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) +def _open_lock_file(lock_path): + return open(lock_path, "a", encoding="utf-8") + + +async def _acquire_thread_lock_async(lock: threading.Lock) -> None: + """Acquire a threading.Lock without polling or using the default executor.""" + loop = asyncio.get_running_loop() + acquire_future = loop.run_in_executor(_THREAD_LOCK_EXECUTOR, lock.acquire, True) + + try: + acquired = await asyncio.shield(acquire_future) + except asyncio.CancelledError: + acquire_future.add_done_callback(lambda task: _release_cancelled_lock_acquire(lock, task)) + raise + + if not acquired: + raise RuntimeError("Failed to acquire sandbox thread lock") + + +def _release_cancelled_lock_acquire(lock: threading.Lock, task: asyncio.Future[bool]) -> None: + """Release a lock acquired after its awaiting coroutine was cancelled.""" + if task.cancelled(): + return + + try: + acquired = task.result() + except Exception as e: + logger.warning(f"Cancelled sandbox lock acquisition finished with error: {e}") + return + + if acquired: + lock.release() + + class AioSandboxProvider(SandboxProvider): """Sandbox provider that manages containers running the AIO sandbox. @@ -416,6 +455,96 @@ class AioSandboxProvider(SandboxProvider): self._thread_locks[thread_id] = threading.Lock() return self._thread_locks[thread_id] + def _sandbox_id_for_thread(self, thread_id: str | None) -> str: + """Return deterministic IDs for thread sandboxes and random IDs otherwise.""" + return self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8] + + def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool = False) -> str | None: + """Reuse an active in-process sandbox for a thread if one is still tracked.""" + if thread_id is None: + return None + + with self._lock: + if thread_id not in self._thread_sandboxes: + return None + + existing_id = self._thread_sandboxes[thread_id] + if existing_id in self._sandboxes: + suffix = " (post-lock check)" if post_lock else "" + logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}") + self._last_activity[existing_id] = time.time() + return existing_id + + del self._thread_sandboxes[thread_id] + return None + + def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None: + """Promote a warm-pool sandbox back to active tracking if available.""" + if thread_id is None: + return None + + with self._lock: + if sandbox_id not in self._warm_pool: + return None + + info, _ = self._warm_pool.pop(sandbox_id) + sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) + self._sandboxes[sandbox_id] = sandbox + self._sandbox_infos[sandbox_id] = info + self._last_activity[sandbox_id] = time.time() + self._thread_sandboxes[thread_id] = sandbox_id + + suffix = " (post-lock check)" if post_lock else f" at {info.sandbox_url}" + logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id}{suffix}") + return sandbox_id + + def _recheck_cached_sandbox(self, thread_id: str, sandbox_id: str) -> str | None: + """Re-check in-memory caches after acquiring the cross-process file lock.""" + return self._reuse_in_process_sandbox(thread_id, post_lock=True) or self._reclaim_warm_pool_sandbox(thread_id, sandbox_id, post_lock=True) + + def _register_discovered_sandbox(self, thread_id: str, info: SandboxInfo) -> str: + """Track a sandbox discovered through the backend.""" + sandbox = AioSandbox(id=info.sandbox_id, base_url=info.sandbox_url) + with self._lock: + self._sandboxes[info.sandbox_id] = sandbox + self._sandbox_infos[info.sandbox_id] = info + self._last_activity[info.sandbox_id] = time.time() + self._thread_sandboxes[thread_id] = info.sandbox_id + + logger.info(f"Discovered existing sandbox {info.sandbox_id} for thread {thread_id} at {info.sandbox_url}") + return info.sandbox_id + + def _register_created_sandbox(self, thread_id: str | None, sandbox_id: str, info: SandboxInfo) -> str: + """Track a newly-created sandbox in the active maps.""" + sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) + with self._lock: + self._sandboxes[sandbox_id] = sandbox + self._sandbox_infos[sandbox_id] = info + self._last_activity[sandbox_id] = time.time() + if thread_id: + self._thread_sandboxes[thread_id] = sandbox_id + + logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") + return sandbox_id + + def _replica_count(self) -> tuple[int, int]: + """Return configured replicas and currently tracked sandbox count.""" + replicas = self._config.get("replicas", DEFAULT_REPLICAS) + with self._lock: + total = len(self._sandboxes) + len(self._warm_pool) + return replicas, total + + def _log_replicas_soft_cap(self, replicas: int, sandbox_id: str, evicted: str | None) -> None: + """Log the result of enforcing the warm-pool replica budget.""" + if evicted: + logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}") + return + + # All slots are occupied by active sandboxes — proceed anyway and log. + # The replicas limit is a soft cap; we never forcibly stop a container + # that is actively serving a thread. + logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit") + # ── Core: acquire / get / release / shutdown ───────────────────────── def acquire(self, thread_id: str | None = None) -> str: @@ -440,6 +569,23 @@ class AioSandboxProvider(SandboxProvider): else: return self._acquire_internal(thread_id) + async def acquire_async(self, thread_id: str | None = None) -> str: + """Acquire a sandbox environment without blocking the event loop. + + Mirrors ``acquire()`` while keeping blocking backend operations off the + event loop and using async-native readiness polling for newly created + sandboxes. + """ + if thread_id: + thread_lock = self._get_thread_lock(thread_id) + await _acquire_thread_lock_async(thread_lock) + try: + return await self._acquire_internal_async(thread_id) + finally: + thread_lock.release() + + return await self._acquire_internal_async(thread_id) + def _acquire_internal(self, thread_id: str | None) -> str: """Internal sandbox acquisition with two-layer consistency. @@ -448,33 +594,17 @@ class AioSandboxProvider(SandboxProvider): sandbox_id is deterministic from thread_id so no shared state file is needed — any process can derive the same container name) """ - # ── Layer 1: In-process cache (fast path) ── - if thread_id: - with self._lock: - if thread_id in self._thread_sandboxes: - existing_id = self._thread_sandboxes[thread_id] - if existing_id in self._sandboxes: - logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}") - self._last_activity[existing_id] = time.time() - return existing_id - else: - del self._thread_sandboxes[thread_id] + cached_id = self._reuse_in_process_sandbox(thread_id) + if cached_id is not None: + return cached_id # Deterministic ID for thread-specific, random for anonymous - sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8] + sandbox_id = self._sandbox_id_for_thread(thread_id) # ── Layer 1.5: Warm pool (container still running, no cold-start) ── - if thread_id: - with self._lock: - if sandbox_id in self._warm_pool: - info, _ = self._warm_pool.pop(sandbox_id) - sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) - self._sandboxes[sandbox_id] = sandbox - self._sandbox_infos[sandbox_id] = info - self._last_activity[sandbox_id] = time.time() - self._thread_sandboxes[thread_id] = sandbox_id - logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") - return sandbox_id + reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id) + if reclaimed_id is not None: + return reclaimed_id # ── Layer 2: Backend discovery + create (protected by cross-process lock) ── # Use a file lock so that two processes racing to create the same sandbox @@ -485,6 +615,26 @@ class AioSandboxProvider(SandboxProvider): return self._create_sandbox(thread_id, sandbox_id) + async def _acquire_internal_async(self, thread_id: str | None) -> str: + """Async counterpart to ``_acquire_internal``.""" + cached_id = self._reuse_in_process_sandbox(thread_id) + if cached_id is not None: + return cached_id + + # Deterministic ID for thread-specific, random for anonymous + sandbox_id = self._sandbox_id_for_thread(thread_id) + + # ── Layer 1.5: Warm pool (container still running, no cold-start) ── + reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id) + if reclaimed_id is not None: + return reclaimed_id + + # ── Layer 2: Backend discovery + create (protected by cross-process lock) ── + if thread_id: + return await self._discover_or_create_with_lock_async(thread_id, sandbox_id) + + return await self._create_sandbox_async(thread_id, sandbox_id) + def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str: """Discover an existing sandbox or create a new one under a cross-process file lock. @@ -503,40 +653,50 @@ class AioSandboxProvider(SandboxProvider): locked = True # Re-check in-process caches under the file lock in case another # thread in this process won the race while we were waiting. - with self._lock: - if thread_id in self._thread_sandboxes: - existing_id = self._thread_sandboxes[thread_id] - if existing_id in self._sandboxes: - logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)") - self._last_activity[existing_id] = time.time() - return existing_id - if sandbox_id in self._warm_pool: - info, _ = self._warm_pool.pop(sandbox_id) - sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) - self._sandboxes[sandbox_id] = sandbox - self._sandbox_infos[sandbox_id] = info - self._last_activity[sandbox_id] = time.time() - self._thread_sandboxes[thread_id] = sandbox_id - logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)") - return sandbox_id + cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id) + if cached_id is not None: + return cached_id # Backend discovery: another process may have created the container. discovered = self._backend.discover(sandbox_id) if discovered is not None: - sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url) - with self._lock: - self._sandboxes[discovered.sandbox_id] = sandbox - self._sandbox_infos[discovered.sandbox_id] = discovered - self._last_activity[discovered.sandbox_id] = time.time() - self._thread_sandboxes[thread_id] = discovered.sandbox_id - logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}") - return discovered.sandbox_id + return self._register_discovered_sandbox(thread_id, discovered) return self._create_sandbox(thread_id, sandbox_id) finally: if locked: _unlock_file(lock_file) + async def _discover_or_create_with_lock_async(self, thread_id: str, sandbox_id: str) -> str: + """Async counterpart to ``_discover_or_create_with_lock``.""" + paths = get_paths() + user_id = get_effective_user_id() + await asyncio.to_thread(paths.ensure_thread_dirs, thread_id, user_id=user_id) + lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock" + + lock_file = await asyncio.to_thread(_open_lock_file, lock_path) + locked = False + try: + await asyncio.to_thread(_lock_file_exclusive, lock_file) + locked = True + # Re-check in-process caches under the file lock in case another + # thread in this process won the race while we were waiting. + cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id) + if cached_id is not None: + return cached_id + + # Backend discovery is sync because local discovery may inspect + # Docker and perform a health check; keep it off the event loop. + discovered = await asyncio.to_thread(self._backend.discover, sandbox_id) + if discovered is not None: + return self._register_discovered_sandbox(thread_id, discovered) + + return await self._create_sandbox_async(thread_id, sandbox_id) + finally: + if locked: + await asyncio.to_thread(_unlock_file, lock_file) + await asyncio.to_thread(lock_file.close) + def _evict_oldest_warm(self) -> str | None: """Destroy the oldest container in the warm pool to free capacity. @@ -574,18 +734,10 @@ class AioSandboxProvider(SandboxProvider): # Enforce replicas: only warm-pool containers count toward eviction budget. # Active sandboxes are in use by live threads and must not be forcibly stopped. - replicas = self._config.get("replicas", DEFAULT_REPLICAS) - with self._lock: - total = len(self._sandboxes) + len(self._warm_pool) + replicas, total = self._replica_count() if total >= replicas: evicted = self._evict_oldest_warm() - if evicted: - logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}") - else: - # All slots are occupied by active sandboxes — proceed anyway and log. - # The replicas limit is a soft cap; we never forcibly stop a container - # that is actively serving a thread. - logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit") + self._log_replicas_soft_cap(replicas, sandbox_id, evicted) info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None) @@ -594,16 +746,27 @@ class AioSandboxProvider(SandboxProvider): self._backend.destroy(info) raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}") - sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) - with self._lock: - self._sandboxes[sandbox_id] = sandbox - self._sandbox_infos[sandbox_id] = info - self._last_activity[sandbox_id] = time.time() - if thread_id: - self._thread_sandboxes[thread_id] = sandbox_id + return self._register_created_sandbox(thread_id, sandbox_id, info) - logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") - return sandbox_id + async def _create_sandbox_async(self, thread_id: str | None, sandbox_id: str) -> str: + """Async counterpart to ``_create_sandbox``.""" + extra_mounts = await asyncio.to_thread(self._get_extra_mounts, thread_id) + + # Enforce replicas: only warm-pool containers count toward eviction budget. + # Active sandboxes are in use by live threads and must not be forcibly stopped. + replicas, total = self._replica_count() + if total >= replicas: + evicted = await asyncio.to_thread(self._evict_oldest_warm) + self._log_replicas_soft_cap(replicas, sandbox_id, evicted) + + info = await asyncio.to_thread(self._backend.create, thread_id, sandbox_id, extra_mounts=extra_mounts or None) + + # Wait for sandbox to be ready without blocking the event loop. + if not await wait_for_sandbox_ready_async(info.sandbox_url, timeout=60): + await asyncio.to_thread(self._backend.destroy, info) + raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}") + + return self._register_created_sandbox(thread_id, sandbox_id, info) def get(self, sandbox_id: str) -> Sandbox | None: """Get a sandbox by ID. Updates last activity timestamp. diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/backend.py index 0200ba783..a1db1bf31 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/backend.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import logging import time from abc import ABC, abstractmethod +import httpx import requests from .sandbox_info import SandboxInfo @@ -35,6 +37,34 @@ def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool: return False +async def wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool: + """Async variant of sandbox readiness polling. + + Use this from async runtime paths so sandbox startup waits do not block the + event loop. The synchronous ``wait_for_sandbox_ready`` function remains for + existing synchronous backend/provider call sites. + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + + async with httpx.AsyncClient(timeout=5) as client: + while True: + remaining = deadline - loop.time() + if remaining <= 0: + break + try: + response = await client.get(f"{sandbox_url}/v1/sandbox", timeout=min(5.0, remaining)) + if response.status_code == 200: + return True + except httpx.RequestError: + pass + remaining = deadline - loop.time() + if remaining <= 0: + break + await asyncio.sleep(min(poll_interval, remaining)) + return False + + class SandboxBackend(ABC): """Abstract base for sandbox provisioning backends. @@ -44,7 +74,7 @@ class SandboxBackend(ABC): """ @abstractmethod - def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: + def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: """Create/provision a new sandbox. Args: diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py index 92d933d89..69d838208 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py @@ -241,7 +241,7 @@ class LocalContainerBackend(SandboxBackend): # ── SandboxBackend interface ────────────────────────────────────────── - def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: + def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: """Start a new container and return its connection info. Args: diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py index 9b23e05dc..83925df13 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py @@ -59,7 +59,7 @@ class RemoteSandboxBackend(SandboxBackend): def create( self, - thread_id: str, + thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None, ) -> SandboxInfo: @@ -132,7 +132,7 @@ class RemoteSandboxBackend(SandboxBackend): logger.warning("Provisioner list_running failed: %s", exc) return [] - def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: + def _provisioner_create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: """POST /api/sandboxes → create Pod + Service.""" try: resp = requests.post( diff --git a/backend/packages/harness/deerflow/sandbox/middleware.py b/backend/packages/harness/deerflow/sandbox/middleware.py index deefc2397..f40781333 100644 --- a/backend/packages/harness/deerflow/sandbox/middleware.py +++ b/backend/packages/harness/deerflow/sandbox/middleware.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import NotRequired, override @@ -48,6 +49,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): logger.info(f"Acquiring sandbox {sandbox_id}") return sandbox_id + async def _acquire_sandbox_async(self, thread_id: str) -> str: + provider = get_sandbox_provider() + sandbox_id = await provider.acquire_async(thread_id) + logger.info(f"Acquiring sandbox {sandbox_id}") + return sandbox_id + + async def _release_sandbox_async(self, sandbox_id: str) -> None: + await asyncio.to_thread(get_sandbox_provider().release, sandbox_id) + @override def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: # Skip acquisition if lazy_init is enabled @@ -64,6 +74,23 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): return {"sandbox": {"sandbox_id": sandbox_id}} return super().before_agent(state, runtime) + @override + async def abefore_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: + # Skip acquisition if lazy_init is enabled + if self._lazy_init: + return await super().abefore_agent(state, runtime) + + # Eager initialization (original behavior), but use the async provider + # hook so blocking sandbox startup/polling runs outside the event loop. + if "sandbox" not in state or state["sandbox"] is None: + thread_id = (runtime.context or {}).get("thread_id") + if thread_id is None: + return await super().abefore_agent(state, runtime) + sandbox_id = await self._acquire_sandbox_async(thread_id) + logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}") + return {"sandbox": {"sandbox_id": sandbox_id}} + return await super().abefore_agent(state, runtime) + @override def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: sandbox = state.get("sandbox") @@ -81,3 +108,21 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): # No sandbox to release return super().after_agent(state, runtime) + + @override + async def aafter_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: + sandbox = state.get("sandbox") + if sandbox is not None: + sandbox_id = sandbox["sandbox_id"] + logger.info(f"Releasing sandbox {sandbox_id}") + await self._release_sandbox_async(sandbox_id) + return None + + if (runtime.context or {}).get("sandbox_id") is not None: + sandbox_id = runtime.context.get("sandbox_id") + logger.info(f"Releasing sandbox {sandbox_id} from context") + await self._release_sandbox_async(sandbox_id) + return None + + # No sandbox to release + return await super().aafter_agent(state, runtime) diff --git a/backend/packages/harness/deerflow/sandbox/sandbox_provider.py b/backend/packages/harness/deerflow/sandbox/sandbox_provider.py index 0aa4d619a..b989f7830 100644 --- a/backend/packages/harness/deerflow/sandbox/sandbox_provider.py +++ b/backend/packages/harness/deerflow/sandbox/sandbox_provider.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from deerflow.config import get_app_config @@ -19,6 +20,16 @@ class SandboxProvider(ABC): """ pass + async def acquire_async(self, thread_id: str | None = None) -> str: + """Acquire a sandbox without blocking the event loop. + + Most sandbox providers expose a synchronous lifecycle API because local + Docker/provisioner operations are blocking. Async runtimes should call + this method so those blocking operations run in a worker thread instead + of stalling the event loop. + """ + return await asyncio.to_thread(self.acquire, thread_id) + @abstractmethod def get(self, sandbox_id: str) -> Sandbox | None: """Get a sandbox environment by ID. diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index 2694e9406..c8c0b06fb 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -1,6 +1,8 @@ +import asyncio import posixpath import re import shlex +from collections.abc import Callable from pathlib import Path from langchain.tools import tool @@ -1111,6 +1113,68 @@ def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox: return sandbox +async def ensure_sandbox_initialized_async(runtime: Runtime | None = None) -> Sandbox: + """Async counterpart to ``ensure_sandbox_initialized`` for tool runtimes. + + This keeps lazy sandbox acquisition on the async provider hook, so AIO + sandbox startup and readiness polling do not fall back to synchronous + ``provider.acquire()`` during async tool execution. + """ + if runtime is None: + raise SandboxRuntimeError("Tool runtime not available") + + if runtime.state is None: + raise SandboxRuntimeError("Tool runtime state not available") + + sandbox_state = runtime.state.get("sandbox") + if sandbox_state is not None: + sandbox_id = sandbox_state.get("sandbox_id") + if sandbox_id is not None: + sandbox = get_sandbox_provider().get(sandbox_id) + if sandbox is not None: + if runtime.context is not None: + runtime.context["sandbox_id"] = sandbox_id + return sandbox + + thread_id = runtime.context.get("thread_id") if runtime.context else None + if thread_id is None: + thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None + if thread_id is None: + raise SandboxRuntimeError("Thread ID not available in runtime context") + + provider = get_sandbox_provider() + sandbox_id = await provider.acquire_async(thread_id) + + runtime.state["sandbox"] = {"sandbox_id": sandbox_id} + + sandbox = provider.get(sandbox_id) + if sandbox is None: + raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id) + + if runtime.context is not None: + runtime.context["sandbox_id"] = sandbox_id + return sandbox + + +async def _run_sync_tool_after_async_sandbox_init( + func: Callable[..., str] | None, + runtime: Runtime, + *args: object, +) -> str: + """Initialize lazily via async provider, then run sync tool body off-thread.""" + try: + await ensure_sandbox_initialized_async(runtime) + except SandboxError as e: + return f"Error: {e}" + except Exception as e: + return f"Error: Unexpected error initializing sandbox: {_sanitize_error(e, runtime)}" + + if func is None: + return "Error: Tool implementation not available" + + return await asyncio.to_thread(func, runtime, *args) + + def ensure_thread_directories_exist(runtime: Runtime | None) -> None: """Ensure thread data directories (workspace, uploads, outputs) exist. @@ -1273,6 +1337,13 @@ def bash_tool(runtime: Runtime, description: str, command: str) -> str: return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}" +async def _bash_tool_async(runtime: Runtime, description: str, command: str) -> str: + return await _run_sync_tool_after_async_sandbox_init(bash_tool.func, runtime, description, command) + + +bash_tool.coroutine = _bash_tool_async + + @tool("ls", parse_docstring=True) def ls_tool(runtime: Runtime, description: str, path: str) -> str: """List the contents of a directory up to 2 levels deep in tree format. @@ -1320,6 +1391,13 @@ def ls_tool(runtime: Runtime, description: str, path: str) -> str: return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}" +async def _ls_tool_async(runtime: Runtime, description: str, path: str) -> str: + return await _run_sync_tool_after_async_sandbox_init(ls_tool.func, runtime, description, path) + + +ls_tool.coroutine = _ls_tool_async + + @tool("glob", parse_docstring=True) def glob_tool( runtime: Runtime, @@ -1370,6 +1448,28 @@ def glob_tool( return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}" +async def _glob_tool_async( + runtime: Runtime, + description: str, + pattern: str, + path: str, + include_dirs: bool = False, + max_results: int = _DEFAULT_GLOB_MAX_RESULTS, +) -> str: + return await _run_sync_tool_after_async_sandbox_init( + glob_tool.func, + runtime, + description, + pattern, + path, + include_dirs, + max_results, + ) + + +glob_tool.coroutine = _glob_tool_async + + @tool("grep", parse_docstring=True) def grep_tool( runtime: Runtime, @@ -1440,6 +1540,32 @@ def grep_tool( return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}" +async def _grep_tool_async( + runtime: Runtime, + description: str, + pattern: str, + path: str, + glob: str | None = None, + literal: bool = False, + case_sensitive: bool = False, + max_results: int = _DEFAULT_GREP_MAX_RESULTS, +) -> str: + return await _run_sync_tool_after_async_sandbox_init( + grep_tool.func, + runtime, + description, + pattern, + path, + glob, + literal, + case_sensitive, + max_results, + ) + + +grep_tool.coroutine = _grep_tool_async + + @tool("read_file", parse_docstring=True) def read_file_tool( runtime: Runtime, @@ -1495,6 +1621,19 @@ def read_file_tool( return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}" +async def _read_file_tool_async( + runtime: Runtime, + description: str, + path: str, + start_line: int | None = None, + end_line: int | None = None, +) -> str: + return await _run_sync_tool_after_async_sandbox_init(read_file_tool.func, runtime, description, path, start_line, end_line) + + +read_file_tool.coroutine = _read_file_tool_async + + @tool("write_file", parse_docstring=True) def write_file_tool( runtime: Runtime, @@ -1536,6 +1675,19 @@ def write_file_tool( return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}" +async def _write_file_tool_async( + runtime: Runtime, + description: str, + path: str, + content: str, + append: bool = False, +) -> str: + return await _run_sync_tool_after_async_sandbox_init(write_file_tool.func, runtime, description, path, content, append) + + +write_file_tool.coroutine = _write_file_tool_async + + @tool("str_replace", parse_docstring=True) def str_replace_tool( runtime: Runtime, @@ -1585,3 +1737,25 @@ def str_replace_tool( return f"Error: Permission denied accessing file: {requested_path}" except Exception as e: return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}" + + +async def _str_replace_tool_async( + runtime: Runtime, + description: str, + path: str, + old_str: str, + new_str: str, + replace_all: bool = False, +) -> str: + return await _run_sync_tool_after_async_sandbox_init( + str_replace_tool.func, + runtime, + description, + path, + old_str, + new_str, + replace_all, + ) + + +str_replace_tool.coroutine = _str_replace_tool_async diff --git a/backend/tests/test_aio_sandbox_provider.py b/backend/tests/test_aio_sandbox_provider.py index 732d52170..4b3d215b3 100644 --- a/backend/tests/test_aio_sandbox_provider.py +++ b/backend/tests/test_aio_sandbox_provider.py @@ -1,5 +1,6 @@ """Tests for AioSandboxProvider mount helpers.""" +import asyncio import importlib from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -140,6 +141,182 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc assert unlock_calls == [] +@pytest.mark.anyio +async def test_acquire_async_uses_async_readiness_polling(monkeypatch): + """AioSandboxProvider async creation must not use sync readiness polling.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(None) + provider._config = {"replicas": 3} + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + provider._backend = SimpleNamespace( + create=MagicMock(return_value=aio_mod.SandboxInfo(sandbox_id="sandbox-async", sandbox_url="http://sandbox")), + destroy=MagicMock(), + discover=MagicMock(return_value=None), + ) + + async_readiness_calls: list[tuple[str, int]] = [] + + async def fake_wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool: + async_readiness_calls.append((sandbox_url, timeout)) + return True + + monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready_async", fake_wait_for_sandbox_ready_async) + monkeypatch.setattr( + aio_mod, + "wait_for_sandbox_ready", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("sync readiness should not be used")), + ) + + sandbox_id = await provider._create_sandbox_async("thread-async", "sandbox-async") + + assert sandbox_id == "sandbox-async" + assert async_readiness_calls == [("http://sandbox", 60)] + assert provider._backend.destroy.call_count == 0 + assert provider._thread_sandboxes["thread-async"] == "sandbox-async" + + +@pytest.mark.anyio +async def test_discover_or_create_with_lock_async_offloads_lock_file_open_and_close(tmp_path, monkeypatch): + """Async lock path must not open or close lock files on the event loop.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._discover_or_create_with_lock_async = aio_mod.AioSandboxProvider._discover_or_create_with_lock_async.__get__( + provider, + aio_mod.AioSandboxProvider, + ) + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {"thread-async-lock": "sandbox-async-lock"} + provider._sandboxes = {"sandbox-async-lock": aio_mod.AioSandbox(id="sandbox-async-lock", base_url="http://sandbox")} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + provider._backend = SimpleNamespace(discover=MagicMock(return_value=None)) + + monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path)) + + to_thread_calls: list[object] = [] + + async def fake_to_thread(func, /, *args, **kwargs): + to_thread_calls.append(func) + return func(*args, **kwargs) + + monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread) + + sandbox_id = await provider._discover_or_create_with_lock_async("thread-async-lock", "sandbox-async-lock") + + assert sandbox_id == "sandbox-async-lock" + assert aio_mod._open_lock_file in to_thread_calls + assert any(getattr(func, "__name__", "") == "close" for func in to_thread_calls) + + +@pytest.mark.anyio +async def test_acquire_thread_lock_async_uses_dedicated_executor(monkeypatch): + """Per-thread lock waits should not consume the default asyncio.to_thread pool.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + lock = aio_mod.threading.Lock() + + async def fail_to_thread(*_args, **_kwargs): + raise AssertionError("thread-lock acquisition must not use asyncio.to_thread") + + monkeypatch.setattr(aio_mod.asyncio, "to_thread", fail_to_thread) + + await aio_mod._acquire_thread_lock_async(lock) + try: + assert not lock.acquire(blocking=False) + finally: + lock.release() + + +@pytest.mark.anyio +async def test_acquire_async_cancellation_does_not_leak_thread_lock(tmp_path): + """Cancelled async lock waiters must not leave the per-thread lock held.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + + thread_id = "thread-cancel-lock" + thread_lock = provider._get_thread_lock(thread_id) + thread_lock.acquire() + + task = asyncio.create_task(provider.acquire_async(thread_id)) + await asyncio.sleep(0.05) + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + thread_lock.release() + deadline = asyncio.get_running_loop().time() + 1 + while asyncio.get_running_loop().time() < deadline: + acquired = thread_lock.acquire(blocking=False) + if acquired: + thread_lock.release() + return + await asyncio.sleep(0.01) + + pytest.fail("provider thread lock was leaked after cancelling acquire_async") + + +@pytest.mark.anyio +async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path, monkeypatch): + """A cancelled waiter must not prevent the next live waiter from acquiring.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + + async def fake_acquire_internal_async(thread_id: str | None) -> str: + assert thread_id == "thread-successor-lock" + await asyncio.sleep(0) + return "sandbox-successor" + + monkeypatch.setattr(provider, "_acquire_internal_async", fake_acquire_internal_async) + + thread_id = "thread-successor-lock" + thread_lock = provider._get_thread_lock(thread_id) + thread_lock.acquire() + + cancelled_waiter = asyncio.create_task(provider.acquire_async(thread_id)) + await asyncio.sleep(0.05) + cancelled_waiter.cancel() + try: + await cancelled_waiter + except asyncio.CancelledError: + pass + + live_waiter = asyncio.create_task(provider.acquire_async(thread_id)) + thread_lock.release() + + assert await asyncio.wait_for(live_waiter, timeout=1) == "sandbox-successor" + + deadline = asyncio.get_running_loop().time() + 1 + while asyncio.get_running_loop().time() < deadline: + acquired = thread_lock.acquire(blocking=False) + if acquired: + thread_lock.release() + return + await asyncio.sleep(0.01) + + pytest.fail("provider thread lock was not released after successor acquire_async") + + def test_remote_backend_create_forwards_effective_user_id(monkeypatch): """Provisioner mode must receive user_id so PVC subPath matches user isolation.""" remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend") diff --git a/backend/tests/test_aio_sandbox_readiness.py b/backend/tests/test_aio_sandbox_readiness.py new file mode 100644 index 000000000..1560bbab3 --- /dev/null +++ b/backend/tests/test_aio_sandbox_readiness.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from deerflow.community.aio_sandbox import backend as readiness + + +class _FakeAsyncClient: + def __init__(self, *, responses: list[object], calls: list[str], timeout: float, request_timeouts: list[float] | None = None) -> None: + self._responses = responses + self._calls = calls + self._timeout = timeout + self._request_timeouts = request_timeouts + + async def __aenter__(self) -> _FakeAsyncClient: + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + async def get(self, url: str, *, timeout: float): + self._calls.append(url) + if self._request_timeouts is not None: + self._request_timeouts.append(timeout) + response = self._responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + +class _FakeLoop: + def __init__(self, times: list[float]) -> None: + self._times = times + self._index = 0 + + def time(self) -> float: + value = self._times[self._index] + self._index += 1 + return value + + +@pytest.mark.anyio +async def test_wait_for_sandbox_ready_async_uses_nonblocking_polling(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + sleeps: list[float] = [] + + def fake_client(*, timeout: float): + return _FakeAsyncClient( + responses=[SimpleNamespace(status_code=503), SimpleNamespace(status_code=200)], + calls=calls, + timeout=timeout, + ) + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + + monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client) + monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep) + monkeypatch.setattr(readiness.requests, "get", lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("requests.get should not be used"))) + monkeypatch.setattr(readiness.time, "sleep", lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("time.sleep should not be used"))) + + assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.05) is True + + assert calls == ["http://sandbox/v1/sandbox", "http://sandbox/v1/sandbox"] + assert sleeps == [0.05] + + +@pytest.mark.anyio +async def test_wait_for_sandbox_ready_async_retries_request_errors(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + sleeps: list[float] = [] + + def fake_client(*, timeout: float): + return _FakeAsyncClient( + responses=[readiness.httpx.ConnectError("not ready"), SimpleNamespace(status_code=200)], + calls=calls, + timeout=timeout, + ) + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + + monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client) + monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep) + + assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.01) is True + + assert len(calls) == 2 + assert sleeps == [0.01] + + +@pytest.mark.anyio +async def test_wait_for_sandbox_ready_async_clamps_request_and_sleep_to_deadline(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + request_timeouts: list[float] = [] + sleeps: list[float] = [] + + def fake_client(*, timeout: float): + return _FakeAsyncClient( + responses=[SimpleNamespace(status_code=503)], + calls=calls, + timeout=timeout, + request_timeouts=request_timeouts, + ) + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + + monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client) + monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep) + monkeypatch.setattr(readiness.asyncio, "get_running_loop", lambda: _FakeLoop([100.0, 100.5, 101.75, 102.0])) + + assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=2, poll_interval=1.0) is False + + assert calls == ["http://sandbox/v1/sandbox"] + assert request_timeouts == [1.5] + assert sleeps == [0.25] diff --git a/backend/tests/test_remote_sandbox_backend.py b/backend/tests/test_remote_sandbox_backend.py index ed4dd7991..beb7564c5 100644 --- a/backend/tests/test_remote_sandbox_backend.py +++ b/backend/tests/test_remote_sandbox_backend.py @@ -159,6 +159,26 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch): assert info.sandbox_url == "http://k3s:31001" +def test_provisioner_create_accepts_anonymous_thread_id(monkeypatch): + backend = RemoteSandboxBackend("http://provisioner:8002") + + def mock_post(url: str, json: dict, timeout: int): + assert url == "http://provisioner:8002/api/sandboxes" + assert json == { + "sandbox_id": "anon123", + "thread_id": None, + "user_id": "test-user-autouse", + } + assert timeout == 30 + return _StubResponse(payload={"sandbox_id": "anon123", "sandbox_url": "http://k3s:31002"}) + + monkeypatch.setattr(requests, "post", mock_post) + + info = backend.create(None, "anon123") + assert info.sandbox_id == "anon123" + assert info.sandbox_url == "http://k3s:31002" + + def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch): backend = RemoteSandboxBackend("http://provisioner:8002") diff --git a/backend/tests/test_sandbox_middleware.py b/backend/tests/test_sandbox_middleware.py new file mode 100644 index 000000000..e3daa3088 --- /dev/null +++ b/backend/tests/test_sandbox_middleware.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import asyncio + +import pytest +from langchain.agents.middleware import AgentMiddleware +from langchain.tools import ToolRuntime +from langgraph.runtime import Runtime + +from deerflow.sandbox.middleware import SandboxMiddleware +from deerflow.sandbox.sandbox import Sandbox +from deerflow.sandbox.sandbox_provider import SandboxProvider, reset_sandbox_provider, set_sandbox_provider +from deerflow.sandbox.search import GrepMatch +from deerflow.sandbox.tools import ls_tool + + +class _SyncProvider(SandboxProvider): + def __init__(self) -> None: + self.thread_ids: list[str | None] = [] + + def acquire(self, thread_id: str | None = None) -> str: + self.thread_ids.append(thread_id) + return "sync-sandbox" + + def get(self, sandbox_id: str) -> Sandbox | None: + return None + + def release(self, sandbox_id: str) -> None: + return None + + +class _SandboxStub(Sandbox): + def execute_command(self, command: str) -> str: + return "OK" + + def read_file(self, path: str) -> str: + return "content" + + def download_file(self, path: str) -> bytes: + return b"content" + + def list_dir(self, path: str, max_depth: int = 2) -> list[str]: + return ["/mnt/user-data/workspace/file.txt"] + + def write_file(self, path: str, content: str, append: bool = False) -> None: + return None + + def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]: + return [], False + + def grep( + self, + path: str, + pattern: str, + *, + glob: str | None = None, + literal: bool = False, + case_sensitive: bool = False, + max_results: int = 100, + ) -> tuple[list[GrepMatch], bool]: + return [], False + + def update_file(self, path: str, content: bytes) -> None: + return None + + +class _AsyncOnlyProvider(SandboxProvider): + def __init__(self) -> None: + self.thread_ids: list[str | None] = [] + self.released_ids: list[str] = [] + self.sandbox = _SandboxStub("async-sandbox") + + def acquire(self, thread_id: str | None = None) -> str: + raise AssertionError("async middleware should not call sync acquire") + + async def acquire_async(self, thread_id: str | None = None) -> str: + self.thread_ids.append(thread_id) + return "async-sandbox" + + def get(self, sandbox_id: str) -> Sandbox | None: + if sandbox_id == "async-sandbox": + return self.sandbox + return None + + def release(self, sandbox_id: str) -> None: + self.released_ids.append(sandbox_id) + return None + + +@pytest.mark.anyio +async def test_provider_default_acquire_async_offloads_sync_acquire(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _SyncProvider() + calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, /, *args): + calls.append((func, args)) + return func(*args) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + + sandbox_id = await provider.acquire_async("thread-1") + + assert sandbox_id == "sync-sandbox" + assert provider.thread_ids == ["thread-1"] + assert calls == [(provider.acquire, ("thread-1",))] + + +@pytest.mark.anyio +async def test_abefore_agent_uses_async_provider_acquire() -> None: + provider = _AsyncOnlyProvider() + set_sandbox_provider(provider) + try: + middleware = SandboxMiddleware(lazy_init=False) + + result = await middleware.abefore_agent({}, Runtime(context={"thread_id": "thread-2"})) + finally: + reset_sandbox_provider() + + assert result == {"sandbox": {"sandbox_id": "async-sandbox"}} + assert provider.thread_ids == ["thread-2"] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("middleware", "state", "runtime"), + [ + (SandboxMiddleware(lazy_init=True), {}, Runtime(context={"thread_id": "thread-lazy"})), + (SandboxMiddleware(lazy_init=False), {}, Runtime(context={})), + (SandboxMiddleware(lazy_init=False), {"sandbox": {"sandbox_id": "existing"}}, Runtime(context={"thread_id": "thread-existing"})), + ], +) +async def test_abefore_agent_delegates_to_super_when_not_acquiring( + monkeypatch: pytest.MonkeyPatch, + middleware: SandboxMiddleware, + state: dict, + runtime: Runtime, +) -> None: + calls: list[tuple[dict, Runtime]] = [] + + async def fake_super_abefore_agent(self, state_arg, runtime_arg): + calls.append((state_arg, runtime_arg)) + return {"delegated": True} + + monkeypatch.setattr(AgentMiddleware, "abefore_agent", fake_super_abefore_agent) + + result = await middleware.abefore_agent(state, runtime) + + assert result == {"delegated": True} + assert calls == [(state, runtime)] + + +@pytest.mark.anyio +async def test_default_lazy_tool_acquisition_uses_async_provider() -> None: + provider = _AsyncOnlyProvider() + set_sandbox_provider(provider) + try: + runtime = ToolRuntime( + state={}, + context={"thread_id": "thread-lazy"}, + config={"configurable": {}}, + stream_writer=lambda _: None, + tools=[], + tool_call_id="call-1", + store=None, + ) + + result = await ls_tool.ainvoke({"runtime": runtime, "description": "list workspace", "path": "/mnt/user-data/workspace"}) + finally: + reset_sandbox_provider() + + assert result == "/mnt/user-data/workspace/file.txt" + assert provider.thread_ids == ["thread-lazy"] + assert runtime.state["sandbox"] == {"sandbox_id": "async-sandbox"} + assert runtime.context["sandbox_id"] == "async-sandbox" + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("state", "runtime", "expected_sandbox_id"), + [ + ({"sandbox": {"sandbox_id": "state-sandbox"}}, Runtime(context={}), "state-sandbox"), + ({}, Runtime(context={"sandbox_id": "context-sandbox"}), "context-sandbox"), + ], +) +async def test_aafter_agent_releases_sandbox_off_thread( + monkeypatch: pytest.MonkeyPatch, + state: dict, + runtime: Runtime, + expected_sandbox_id: str, +) -> None: + provider = _AsyncOnlyProvider() + to_thread_calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, /, *args): + to_thread_calls.append((func, args)) + return func(*args) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + set_sandbox_provider(provider) + try: + result = await SandboxMiddleware().aafter_agent(state, runtime) + finally: + reset_sandbox_provider() + + assert result is None + assert provider.released_ids == [expected_sandbox_id] + assert to_thread_calls == [(provider.release, (expected_sandbox_id,))] + + +@pytest.mark.anyio +async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[dict, Runtime]] = [] + + async def fake_super_aafter_agent(self, state_arg, runtime_arg): + calls.append((state_arg, runtime_arg)) + return {"delegated": True} + + monkeypatch.setattr(AgentMiddleware, "aafter_agent", fake_super_aafter_agent) + + state = {} + runtime = Runtime(context={}) + result = await SandboxMiddleware().aafter_agent(state, runtime) + + assert result == {"delegated": True} + assert calls == [(state, runtime)]