diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 01da37b05..ac12023b2 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -284,7 +284,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti **Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop. **Implementations**: - `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`. -- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation +- `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation. Active-cache and warm-pool entries are checked with the backend during acquire/reuse; definitively dead containers are dropped from all in-process maps so the thread can discover or create a fresh sandbox instead of reusing a stale client. Backend health-check failures are treated as unknown, not dead; local discovery likewise treats an unverifiable container as not adoptable and falls through to create rather than failing acquire. `get()` remains an in-memory lookup for event-loop-safe tool paths. **Virtual Path System**: - Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills` diff --git a/backend/README.md b/backend/README.md index 20ef72d50..04f3f67bd 100644 --- a/backend/README.md +++ b/backend/README.md @@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern: Per-thread isolated execution with virtual path translation: - **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir` -- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. +- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. `AioSandboxProvider` validates active-cache and warm-pool containers during acquire/reuse, dropping definitively dead entries so a thread can provision a fresh sandbox after an unexpected container exit while keeping `get()` as an in-memory lookup. Backend health-check failures are treated as unknown, not dead, and a container that cannot be verified during discovery is simply not adopted (acquire falls through to create instead of failing). - **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories - **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory - **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py index ec84f23df..095663367 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py @@ -470,14 +470,32 @@ class AioSandboxProvider(SandboxProvider): existing_id = self._thread_sandboxes[thread_id] if existing_id in self._sandboxes: - suffix = " (post-lock check)" if post_lock else "" - logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}") - self._last_activity[existing_id] = time.time() - return existing_id + info = self._sandbox_infos.get(existing_id) + else: + del self._thread_sandboxes[thread_id] + return None - del self._thread_sandboxes[thread_id] + alive = self._check_tracked_sandbox_alive(existing_id, info) if info is not None else True + if alive is False: + self._drop_unhealthy_sandbox( + existing_id, + "in-process cache failed health check", + expected_info=info, + ) return None + with self._lock: + if self._thread_sandboxes.get(thread_id) != existing_id: + return None + if existing_id not in self._sandboxes: + self._thread_sandboxes.pop(thread_id, None) + return None + + suffix = " (post-lock check)" if post_lock else "" + logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}") + self._last_activity[existing_id] = time.time() + return existing_id + def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None: """Promote a warm-pool sandbox back to active tracking if available.""" if thread_id is None: @@ -487,7 +505,22 @@ class AioSandboxProvider(SandboxProvider): if sandbox_id not in self._warm_pool: return None - info, _ = self._warm_pool.pop(sandbox_id) + info, _ = self._warm_pool[sandbox_id] + + alive = self._check_tracked_sandbox_alive(sandbox_id, info) + if alive is False: + self._drop_unhealthy_sandbox( + sandbox_id, + "warm-pool cache failed health check", + expected_info=info, + ) + return None + + with self._lock: + warm_item = self._warm_pool.pop(sandbox_id, None) + if warm_item is None: + return None + info, _ = warm_item sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) self._sandboxes[sandbox_id] = sandbox self._sandbox_infos[sandbox_id] = info @@ -527,6 +560,70 @@ class AioSandboxProvider(SandboxProvider): logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") return sandbox_id + def _check_tracked_sandbox_alive(self, sandbox_id: str, info: SandboxInfo) -> bool | None: + """Return whether a tracked sandbox appears alive, or None if unknown.""" + try: + return self._backend.is_alive(info) + except Exception as e: + logger.warning(f"Failed to check sandbox {sandbox_id} health: {e}") + return None + + def _remove_tracked_sandbox( + self, + sandbox_id: str, + *, + expected_info: SandboxInfo | None = None, + ) -> tuple[Sandbox | None, SandboxInfo | None, bool]: + """Remove a sandbox from in-process tracking maps. + + When expected_info is provided, removal only happens if the currently + tracked active or warm-pool entry is the exact info object that was + checked. This prevents a stale health-check result from deleting a + freshly recreated sandbox with the same deterministic id. + """ + thread_ids_to_remove: list[str] = [] + + with self._lock: + active_info = self._sandbox_infos.get(sandbox_id) + warm_item = self._warm_pool.get(sandbox_id) + warm_info = warm_item[0] if warm_item is not None else None + if expected_info is not None and active_info is not expected_info and warm_info is not expected_info: + return None, None, False + + sandbox = self._sandboxes.pop(sandbox_id, None) + info = self._sandbox_infos.pop(sandbox_id, None) + thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id] + for tid in thread_ids_to_remove: + del self._thread_sandboxes[tid] + self._last_activity.pop(sandbox_id, None) + if info is None and sandbox_id in self._warm_pool: + info, _ = self._warm_pool.pop(sandbox_id) + else: + self._warm_pool.pop(sandbox_id, None) + + return sandbox, info, True + + def _drop_unhealthy_sandbox(self, sandbox_id: str, reason: str, *, expected_info: SandboxInfo | None = None) -> None: + """Remove and destroy a sandbox after a definitive failed health check.""" + sandbox, info, removed = self._remove_tracked_sandbox(sandbox_id, expected_info=expected_info) + if not removed: + logger.info(f"Skipped dropping sandbox {sandbox_id}: tracked info changed after health check") + return + + if sandbox is not None: + try: + sandbox.close() + except Exception as e: + logger.warning(f"Error closing unhealthy sandbox {sandbox_id}: {e}") + + if info is not None: + try: + self._backend.destroy(info) + except Exception as e: + logger.warning(f"Error destroying unhealthy sandbox {sandbox_id}: {e}") + + logger.warning(f"Dropped unhealthy sandbox {sandbox_id}: {reason}") + def _replica_count(self) -> tuple[int, int]: """Return configured replicas and currently tracked sandbox count.""" replicas = self._config.get("replicas", DEFAULT_REPLICAS) @@ -617,7 +714,7 @@ class AioSandboxProvider(SandboxProvider): async def _acquire_internal_async(self, thread_id: str | None) -> str: """Async counterpart to ``_acquire_internal``.""" - cached_id = self._reuse_in_process_sandbox(thread_id) + cached_id = await asyncio.to_thread(self._reuse_in_process_sandbox, thread_id) if cached_id is not None: return cached_id @@ -625,7 +722,7 @@ class AioSandboxProvider(SandboxProvider): sandbox_id = self._sandbox_id_for_thread(thread_id) # ── Layer 1.5: Warm pool (container still running, no cold-start) ── - reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id) + reclaimed_id = await asyncio.to_thread(self._reclaim_warm_pool_sandbox, thread_id, sandbox_id) if reclaimed_id is not None: return reclaimed_id @@ -681,7 +778,7 @@ class AioSandboxProvider(SandboxProvider): locked = True # Re-check in-process caches under the file lock in case another # thread in this process won the race while we were waiting. - cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id) + cached_id = await asyncio.to_thread(self._recheck_cached_sandbox, thread_id, sandbox_id) if cached_id is not None: return cached_id @@ -837,22 +934,7 @@ class AioSandboxProvider(SandboxProvider): Args: sandbox_id: The ID of the sandbox to destroy. """ - info = None - sandbox = None - thread_ids_to_remove: list[str] = [] - - with self._lock: - sandbox = self._sandboxes.pop(sandbox_id, None) - info = self._sandbox_infos.pop(sandbox_id, None) - thread_ids_to_remove = [tid for tid, sid in self._thread_sandboxes.items() if sid == sandbox_id] - for tid in thread_ids_to_remove: - del self._thread_sandboxes[tid] - self._last_activity.pop(sandbox_id, None) - # Also pull from warm pool if it was parked there - if info is None and sandbox_id in self._warm_pool: - info, _ = self._warm_pool.pop(sandbox_id) - else: - self._warm_pool.pop(sandbox_id, None) + sandbox, info, _ = self._remove_tracked_sandbox(sandbox_id) if sandbox is not None: # Defense-in-depth: close() already swallows its own errors; this diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py index 69d838208..58eaf6817 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py @@ -169,6 +169,24 @@ def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str | return "0.0.0.0" +def _is_no_such_container_error(stderr: str, container_name: str) -> bool: + """Return True only when stderr definitively says the container does not exist. + + Docker reports "No such object" / "No such container". Apple Container + reports a generic "not found", so that phrase is only trusted when the + message also names the inspected container (or refers to a + container/object); transient failures whose text happens to contain + "not found" (e.g. "command not found", "context not found") must stay on + the raise path instead of being misread as a dead container. + """ + message = stderr.lower() + if "no such object" in message or "no such container" in message: + return True + if "not found" not in message: + return False + return container_name.lower() in message or "container" in message or "object" in message + + class LocalContainerBackend(SandboxBackend): """Backend that manages sandbox containers locally using Docker or Apple Container. @@ -335,11 +353,21 @@ class LocalContainerBackend(SandboxBackend): sandbox_id: The deterministic sandbox ID (determines container name). Returns: - SandboxInfo if container found and healthy, None otherwise. + SandboxInfo if container found and healthy, None otherwise. A + failed runtime check (e.g. transient daemon error) also returns + None — discovery must not adopt a container it cannot verify, and + falling through to create keeps acquire recoverable instead of + hard-failing on a hiccup. """ container_name = f"{self._container_prefix}-{sandbox_id}" - if not self._is_container_running(container_name): + try: + running = self._is_container_running(container_name) + except RuntimeError as e: + logger.warning(f"Could not verify container {container_name} during discovery; not adopting it: {e}") + return None + + if not running: return None port = self._get_container_port(container_name) @@ -582,6 +610,13 @@ class LocalContainerBackend(SandboxBackend): This enables cross-process container discovery — any process can detect containers started by another process via the deterministic container name. + + Raises: + RuntimeError: If the container runtime cannot answer the inspect + query. A failed check is intentionally distinct from a + definitive "container does not exist" result so callers do not + destroy healthy containers during transient Docker/Container + daemon failures. """ try: result = subprocess.run( @@ -590,9 +625,14 @@ class LocalContainerBackend(SandboxBackend): text=True, timeout=5, ) - return result.returncode == 0 and result.stdout.strip().lower() == "true" - except (subprocess.CalledProcessError, subprocess.TimeoutExpired): + except subprocess.TimeoutExpired as exc: + raise RuntimeError(f"Timed out checking container {container_name}") from exc + + if result.returncode == 0: + return result.stdout.strip().lower() == "true" + if _is_no_such_container_error(result.stderr, container_name): return False + raise RuntimeError(f"Failed to inspect container {container_name}: {result.stderr.strip()}") def _get_container_port(self, container_name: str) -> int | None: """Get the host port of a running container. diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py index 83925df13..c04f4e75d 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py @@ -176,12 +176,16 @@ class RemoteSandboxBackend(SandboxBackend): f"{self._provisioner_url}/api/sandboxes/{sandbox_id}", timeout=10, ) - if resp.ok: - data = resp.json() - return data.get("status") == "Running" - return False - except requests.RequestException: + except requests.RequestException as exc: + raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: {exc}") from exc + + if resp.status_code == 404: return False + if not resp.ok: + raise RuntimeError(f"Provisioner health check failed for {sandbox_id}: HTTP {resp.status_code} {resp.text}") + + data = resp.json() + return data.get("status") == "Running" def _provisioner_discover(self, sandbox_id: str) -> SandboxInfo | None: """GET /api/sandboxes/{sandbox_id} → discover existing sandbox.""" diff --git a/backend/packages/harness/deerflow/sandbox/middleware.py b/backend/packages/harness/deerflow/sandbox/middleware.py index f40781333..5bdb5a700 100644 --- a/backend/packages/harness/deerflow/sandbox/middleware.py +++ b/backend/packages/harness/deerflow/sandbox/middleware.py @@ -1,10 +1,15 @@ import asyncio import logging +from collections.abc import Awaitable, Callable +from dataclasses import replace as dc_replace from typing import NotRequired, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.runtime import Runtime +from langgraph.types import Command from deerflow.agents.thread_state import SandboxState, ThreadDataState from deerflow.sandbox import get_sandbox_provider @@ -126,3 +131,87 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): # No sandbox to release return await super().aafter_agent(state, runtime) + + # ------------------------------------------------------------------ + # Tool-call wrappers: persist lazily-acquired sandbox state into the + # graph state via Command(update=...). + # + # Background: + # ``ensure_sandbox_initialized*`` in ``deerflow.sandbox.tools`` mutates + # ``runtime.state["sandbox"]`` directly. That mutation is local to the + # current tool invocation and is NOT picked up by LangGraph's channel + # reducer, so subsequent graph steps (and downstream consumers such as + # ``ToolOutputBudgetMiddleware`` and the sub-agent ``task_tool``) + # cannot observe the sandbox id. Wrapping the tool call lets us detect + # a fresh lazy init by diffing the state snapshot before/after the + # handler and emit a proper state update via ``Command``. + # ------------------------------------------------------------------ + + @staticmethod + def _read_sandbox_id_from_state(state: object) -> str | None: + if not isinstance(state, dict): + return None + sandbox_state = state.get("sandbox") + if not isinstance(sandbox_state, dict): + return None + sandbox_id = sandbox_state.get("sandbox_id") + return sandbox_id if isinstance(sandbox_id, str) else None + + @staticmethod + def _attach_sandbox_update(result: ToolMessage | Command, sandbox_id: str) -> ToolMessage | Command: + """Wrap or merge ``result`` so that ``sandbox.sandbox_id`` is persisted. + + - ``ToolMessage`` -> ``Command(update={"sandbox": ..., "messages": [msg]})`` + - ``Command`` with dict update -> merge ``sandbox`` key, preserve all + existing fields (``messages``, ``goto``, ``graph``, ``resume``, ...). + - ``Command`` with non-dict / None update -> leave it untouched to + avoid silent data loss on unknown update shapes. + """ + sandbox_update = {"sandbox": {"sandbox_id": sandbox_id}} + + if isinstance(result, ToolMessage): + return Command(update={**sandbox_update, "messages": [result]}) + + existing_update = result.update + if isinstance(existing_update, dict): + merged_update = {**existing_update, **sandbox_update} + return dc_replace(result, update=merged_update) + return result + + @staticmethod + def _read_sandbox_id_from_request(request: ToolCallRequest) -> str | None: + """Read sandbox_id from runtime.state (where ensure_sandbox_initialized writes).""" + runtime = request.runtime + if runtime is None or runtime.state is None: + return None + return SandboxMiddleware._read_sandbox_id_from_state(runtime.state) + + @override + def wrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command], + ) -> ToolMessage | Command: + prev_sandbox_id = self._read_sandbox_id_from_request(request) + result = handler(request) + if prev_sandbox_id is not None: + return result + curr_sandbox_id = self._read_sandbox_id_from_request(request) + if curr_sandbox_id is None: + return result + return self._attach_sandbox_update(result, curr_sandbox_id) + + @override + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]], + ) -> ToolMessage | Command: + prev_sandbox_id = self._read_sandbox_id_from_request(request) + result = await handler(request) + if prev_sandbox_id is not None: + return result + curr_sandbox_id = self._read_sandbox_id_from_request(request) + if curr_sandbox_id is None: + return result + return self._attach_sandbox_update(result, curr_sandbox_id) diff --git a/backend/tests/fixtures/replay/write_read_file.ultra.events.json b/backend/tests/fixtures/replay/write_read_file.ultra.events.json index babf90d9d..83976df10 100644 --- a/backend/tests/fixtures/replay/write_read_file.ultra.events.json +++ b/backend/tests/fixtures/replay/write_read_file.ultra.events.json @@ -69,6 +69,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -79,6 +80,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -89,6 +91,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -99,6 +102,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -109,6 +113,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" @@ -119,6 +124,7 @@ "keys": [ "artifacts", "messages", + "sandbox", "thread_data", "title", "viewed_images" diff --git a/backend/tests/test_aio_sandbox_local_backend.py b/backend/tests/test_aio_sandbox_local_backend.py index 333c3eb53..bf0767bc7 100644 --- a/backend/tests/test_aio_sandbox_local_backend.py +++ b/backend/tests/test_aio_sandbox_local_backend.py @@ -1,7 +1,10 @@ import logging import os +import subprocess from types import SimpleNamespace +import pytest + from deerflow.community.aio_sandbox.local_backend import ( LocalContainerBackend, _format_container_command_for_log, @@ -234,3 +237,99 @@ def test_start_container_keeps_apple_container_port_format(monkeypatch): captured_cmd = _capture_start_container_command(monkeypatch, backend, runtime="container") assert captured_cmd[captured_cmd.index("-p") + 1] == "18080:8080" + + +def _backend_for_inspect_tests() -> LocalContainerBackend: + backend = LocalContainerBackend( + image="sandbox:latest", + base_port=8080, + container_prefix="sandbox", + config_mounts=[], + environment={}, + ) + backend._runtime = "docker" + return backend + + +def test_is_container_running_false_when_container_missing(monkeypatch): + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + return SimpleNamespace(stdout="", stderr="Error: No such object: sandbox-missing", returncode=1) + + monkeypatch.setattr("subprocess.run", fake_run) + + assert backend._is_container_running("sandbox-missing") is False + + +def test_is_container_running_raises_on_runtime_error(monkeypatch): + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1) + + monkeypatch.setattr("subprocess.run", fake_run) + + with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"): + backend._is_container_running("sandbox-busy") + + +def test_is_container_running_raises_on_timeout(monkeypatch): + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"]) + + monkeypatch.setattr("subprocess.run", fake_run) + + with pytest.raises(RuntimeError, match="Timed out checking container sandbox-timeout"): + backend._is_container_running("sandbox-timeout") + + +def test_discover_returns_none_when_runtime_check_fails(monkeypatch): + """A transient daemon error during discovery must fall through to create, not fail acquire.""" + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + return SimpleNamespace(stdout="", stderr="Cannot connect to the Docker daemon", returncode=1) + + monkeypatch.setattr("subprocess.run", fake_run) + + assert backend.discover("sandbox-blip") is None + + +def test_discover_returns_none_when_runtime_check_times_out(monkeypatch): + """An inspect timeout during discovery must not propagate out of discover().""" + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs["timeout"]) + + monkeypatch.setattr("subprocess.run", fake_run) + + assert backend.discover("sandbox-timeout") is None + + +def test_is_container_running_false_on_apple_container_not_found(monkeypatch): + """Apple Container's generic "not found" is trusted when it names the container.""" + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + return SimpleNamespace(stdout="", stderr='Error: not found: "sandbox-apple"', returncode=1) + + monkeypatch.setattr("subprocess.run", fake_run) + + assert backend._is_container_running("sandbox-apple") is False + + +def test_is_container_running_raises_on_unrelated_not_found_error(monkeypatch): + """Transient errors whose text contains "not found" must not be misread as a dead container.""" + backend = _backend_for_inspect_tests() + + def fake_run(cmd, **kwargs): + return SimpleNamespace(stdout="", stderr="Error: credential helper not found in $PATH", returncode=1) + + monkeypatch.setattr("subprocess.run", fake_run) + + with pytest.raises(RuntimeError, match="Failed to inspect container sandbox-busy"): + backend._is_container_running("sandbox-busy") diff --git a/backend/tests/test_aio_sandbox_provider.py b/backend/tests/test_aio_sandbox_provider.py index ada99e744..974997a69 100644 --- a/backend/tests/test_aio_sandbox_provider.py +++ b/backend/tests/test_aio_sandbox_provider.py @@ -317,6 +317,28 @@ async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path, pytest.fail("provider thread lock was not released after successor acquire_async") +@pytest.mark.anyio +async def test_acquire_internal_async_offloads_cached_reuse_health_check(tmp_path, monkeypatch): + """Async cached reuse must keep backend health checks off the event loop.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider, _sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-cached-async") + provider._thread_sandboxes = {"thread-cached-async": "sandbox-cached-async"} + provider._backend.is_alive = MagicMock(return_value=True) + + to_thread_calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, /, *args, **kwargs): + to_thread_calls.append((func, args)) + return func(*args, **kwargs) + + monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread) + + sandbox_id = await provider._acquire_internal_async("thread-cached-async") + + assert sandbox_id == "sandbox-cached-async" + assert to_thread_calls == [(provider._reuse_in_process_sandbox, ("thread-cached-async",))] + + def test_remote_backend_create_forwards_effective_user_id(monkeypatch): """Provisioner mode must receive user_id so PVC subPath matches user isolation.""" remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend") @@ -424,6 +446,136 @@ def test_release_swallows_close_errors(tmp_path, caplog): assert "sandbox-rel-err" in provider._warm_pool +def test_get_uses_in_memory_registry_only(tmp_path): + """get() must stay event-loop safe by avoiding backend health checks.""" + provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead") + provider._backend.is_alive = MagicMock(side_effect=AssertionError("get must not call backend health checks")) + + assert provider.get("sandbox-dead") is sandbox + + +def test_acquire_drops_dead_cached_sandbox(tmp_path, monkeypatch): + """acquire() must replace a stale active cache entry after its container dies.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dead") + provider._thread_locks = {} + provider._thread_sandboxes = {"thread-dead": "sandbox-dead"} + provider._config = {"replicas": 3} + provider._backend.is_alive = MagicMock(return_value=False) + provider._backend.discover = MagicMock(return_value=None) + provider._backend.create = MagicMock( + return_value=aio_mod.SandboxInfo( + sandbox_id="sandbox-dead", + sandbox_url="http://fresh-sandbox", + container_name="deer-flow-sandbox-sandbox-dead", + ) + ) + + monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-dead") + monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: []) + monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path)) + monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None) + monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True) + + sandbox_id = provider.acquire("thread-dead") + + assert sandbox_id == "sandbox-dead" + sandbox.close.assert_called_once_with() + provider._backend.destroy.assert_called_once() + provider._backend.create.assert_called_once() + assert provider._thread_sandboxes["thread-dead"] == "sandbox-dead" + assert provider._sandboxes["sandbox-dead"].base_url == "http://fresh-sandbox" + + +def test_acquire_keeps_cached_sandbox_when_health_check_errors(tmp_path): + """Transient backend health-check errors must not destroy a tracked sandbox.""" + provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-transient") + provider._thread_locks = {} + provider._thread_sandboxes = {"thread-transient": "sandbox-transient"} + provider._backend.is_alive = MagicMock(side_effect=OSError("docker daemon busy")) + + sandbox_id = provider.acquire("thread-transient") + + assert sandbox_id == "sandbox-transient" + sandbox.close.assert_not_called() + provider._backend.destroy.assert_not_called() + assert provider._sandboxes["sandbox-transient"] is sandbox + + +def test_drop_unhealthy_sandbox_skips_recreated_entry(tmp_path): + """A stale health-check result must not delete a newly registered sandbox.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._lock = aio_mod.threading.Lock() + provider._warm_pool = {} + provider._last_activity = {"sandbox-toctou": 1.0} + provider._thread_sandboxes = {"thread-toctou": "sandbox-toctou"} + old_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://old-sandbox") + new_info = aio_mod.SandboxInfo(sandbox_id="sandbox-toctou", sandbox_url="http://new-sandbox") + new_sandbox = MagicMock() + provider._sandbox_infos = {"sandbox-toctou": new_info} + provider._sandboxes = {"sandbox-toctou": new_sandbox} + provider._backend = SimpleNamespace(destroy=MagicMock()) + + provider._drop_unhealthy_sandbox("sandbox-toctou", "stale health check", expected_info=old_info) + + new_sandbox.close.assert_not_called() + provider._backend.destroy.assert_not_called() + assert provider._sandbox_infos["sandbox-toctou"] is new_info + assert provider._sandboxes["sandbox-toctou"] is new_sandbox + assert provider._thread_sandboxes == {"thread-toctou": "sandbox-toctou"} + + +def test_acquire_skips_dead_warm_pool_sandbox(tmp_path, monkeypatch): + """acquire() must create a fresh sandbox when the warm-pool entry died.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._lock = aio_mod.threading.Lock() + provider._thread_locks = {} + provider._sandboxes = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._warm_pool = { + "sandbox-warm-dead": ( + aio_mod.SandboxInfo( + sandbox_id="sandbox-warm-dead", + sandbox_url="http://stale-sandbox", + container_name="deer-flow-sandbox-sandbox-warm-dead", + ), + 0.0, + ) + } + provider._config = {"replicas": 3} + provider._backend = SimpleNamespace( + is_alive=MagicMock(return_value=False), + destroy=MagicMock(), + discover=MagicMock(return_value=None), + create=MagicMock( + return_value=aio_mod.SandboxInfo( + sandbox_id="sandbox-warm-dead", + sandbox_url="http://fresh-sandbox", + container_name="deer-flow-sandbox-sandbox-warm-dead", + ) + ), + ) + + monkeypatch.setattr(aio_mod.AioSandboxProvider, "_sandbox_id_for_thread", lambda _self, _thread_id: "sandbox-warm-dead") + monkeypatch.setattr(aio_mod.AioSandboxProvider, "_get_extra_mounts", lambda _self, _thread_id: []) + monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path)) + monkeypatch.setattr(aio_mod, "get_effective_user_id", lambda: None) + monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready", lambda _url, timeout=60: True) + + sandbox_id = provider.acquire("thread-warm-dead") + + assert sandbox_id == "sandbox-warm-dead" + provider._backend.destroy.assert_called_once() + provider._backend.create.assert_called_once() + assert provider._warm_pool == {} + assert provider._thread_sandboxes["thread-warm-dead"] == "sandbox-warm-dead" + assert provider._sandboxes["sandbox-warm-dead"].base_url == "http://fresh-sandbox" + + def test_destroy_swallows_close_errors_and_still_destroys_backend(tmp_path, caplog): """A failure in sandbox.close() must not skip backend container destruction.""" provider, sandbox, _ = _make_provider_with_active_sandbox(tmp_path, "sandbox-dest-err") diff --git a/backend/tests/test_remote_sandbox_backend.py b/backend/tests/test_remote_sandbox_backend.py index beb7564c5..3c6f60886 100644 --- a/backend/tests/test_remote_sandbox_backend.py +++ b/backend/tests/test_remote_sandbox_backend.py @@ -257,14 +257,38 @@ def test_provisioner_is_alive_true_only_when_status_running(monkeypatch): assert backend._provisioner_is_alive("abc123") is False -def test_provisioner_is_alive_returns_false_on_request_exception(monkeypatch): +def test_provisioner_is_alive_returns_false_on_404(monkeypatch): + backend = RemoteSandboxBackend("http://provisioner:8002") + + def mock_get(url: str, timeout: int): + return _StubResponse(status_code=404) + + monkeypatch.setattr(requests, "get", mock_get) + assert backend._provisioner_is_alive("abc123") is False + + +def test_provisioner_is_alive_raises_on_request_exception(monkeypatch): backend = RemoteSandboxBackend("http://provisioner:8002") def mock_get(url: str, timeout: int): raise requests.RequestException("boom") monkeypatch.setattr(requests, "get", mock_get) - assert backend._provisioner_is_alive("abc123") is False + with pytest.raises(RuntimeError, match="Provisioner health check failed for abc123"): + backend._provisioner_is_alive("abc123") + + +def test_provisioner_is_alive_raises_on_server_error(monkeypatch): + backend = RemoteSandboxBackend("http://provisioner:8002") + + def mock_get(url: str, timeout: int): + response = _StubResponse(status_code=503) + response.text = "unavailable" + return response + + monkeypatch.setattr(requests, "get", mock_get) + with pytest.raises(RuntimeError, match="HTTP 503 unavailable"): + backend._provisioner_is_alive("abc123") def test_discover_delegates_to_provisioner_discover(monkeypatch): diff --git a/backend/tests/test_sandbox_middleware.py b/backend/tests/test_sandbox_middleware.py index e3daa3088..c584c759a 100644 --- a/backend/tests/test_sandbox_middleware.py +++ b/backend/tests/test_sandbox_middleware.py @@ -5,7 +5,10 @@ import asyncio import pytest from langchain.agents.middleware import AgentMiddleware from langchain.tools import ToolRuntime +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import ToolCallRequest from langgraph.runtime import Runtime +from langgraph.types import Command from deerflow.sandbox.middleware import SandboxMiddleware from deerflow.sandbox.sandbox import Sandbox @@ -223,3 +226,183 @@ async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pyte assert result == {"delegated": True} assert calls == [(state, runtime)] + + +# --------------------------------------------------------------------------- +# wrap_tool_call / awrap_tool_call: persistent sandbox state via Command +# --------------------------------------------------------------------------- + + +def _make_tool_call_request(state: dict) -> ToolCallRequest: + """Build a minimal ToolCallRequest backed by a real ToolRuntime.""" + runtime = ToolRuntime( + state=state, + context={}, + config={"configurable": {}}, + stream_writer=lambda _: None, + tools=[], + tool_call_id="call-1", + store=None, + ) + return ToolCallRequest( + tool_call={"id": "call-1", "name": "bash", "args": {}}, + tool=None, + state=state, + runtime=runtime, + ) + + +def test_wrap_tool_call_emits_command_when_lazy_init_happens() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + + def handler(req: ToolCallRequest) -> ToolMessage: + # Simulate ensure_sandbox_initialized() mutating runtime.state in-place. + req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"} + return ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + result = middleware.wrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert isinstance(result.update, dict) + assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"} + messages = result.update["messages"] + assert len(messages) == 1 + assert messages[0].content == "ok" + assert messages[0].tool_call_id == "call-1" + + +def test_wrap_tool_call_passthrough_when_sandbox_already_in_state() -> None: + middleware = SandboxMiddleware() + state: dict = {"sandbox": {"sandbox_id": "existing"}} + request = _make_tool_call_request(state) + original = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + def handler(req: ToolCallRequest) -> ToolMessage: + return original + + result = middleware.wrap_tool_call(request, handler) + + assert result is original + + +def test_wrap_tool_call_passthrough_when_handler_did_not_initialize_sandbox() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + original = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + def handler(req: ToolCallRequest) -> ToolMessage: + return original + + result = middleware.wrap_tool_call(request, handler) + + assert result is original + + +def test_wrap_tool_call_merges_with_existing_command_update() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + tool_msg = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + def handler(req: ToolCallRequest) -> Command: + req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"} + return Command( + update={ + "messages": [tool_msg], + "viewed_images": {"a.png": {"base64": "x", "mime_type": "image/png"}}, + }, + goto="next-node", + ) + + result = middleware.wrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert result.goto == "next-node" + assert isinstance(result.update, dict) + assert result.update["messages"] == [tool_msg] + assert result.update["viewed_images"] == {"a.png": {"base64": "x", "mime_type": "image/png"}} + assert result.update["sandbox"] == {"sandbox_id": "new-sandbox"} + + +def test_wrap_tool_call_does_not_override_non_dict_update() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + cmd = Command(update=[("messages", [ToolMessage(content="x", tool_call_id="c", name="bash")])]) + + def handler(req: ToolCallRequest) -> Command: + req.runtime.state["sandbox"] = {"sandbox_id": "new-sandbox"} + return cmd + + result = middleware.wrap_tool_call(request, handler) + + # Non-dict update is left untouched to avoid silent data loss. + assert result is cmd + + +@pytest.mark.anyio +async def test_awrap_tool_call_emits_command_when_lazy_init_happens() -> None: + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + + async def handler(req: ToolCallRequest) -> ToolMessage: + req.runtime.state["sandbox"] = {"sandbox_id": "async-new"} + return ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + result = await middleware.awrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert isinstance(result.update, dict) + assert result.update["sandbox"] == {"sandbox_id": "async-new"} + messages = result.update["messages"] + assert len(messages) == 1 + assert messages[0].content == "ok" + + +@pytest.mark.anyio +async def test_awrap_tool_call_passthrough_when_sandbox_already_in_state() -> None: + middleware = SandboxMiddleware() + state: dict = {"sandbox": {"sandbox_id": "existing"}} + request = _make_tool_call_request(state) + original = ToolMessage(content="ok", tool_call_id="call-1", name="bash") + + async def handler(req: ToolCallRequest) -> ToolMessage: + return original + + result = await middleware.awrap_tool_call(request, handler) + + assert result is original + + +def test_wrap_tool_call_preserves_existing_command_fields_when_merging() -> None: + """Regression: when merging sandbox_update into an existing Command, + all other Command fields (e.g. graph, goto, resume) must be preserved. + """ + middleware = SandboxMiddleware() + state: dict = {} + request = _make_tool_call_request(state) + + def handler(req: ToolCallRequest) -> Command: + req.runtime.state["sandbox"] = {"sandbox_id": "sbx-merge"} + return Command( + update={"existing_key": "existing_value"}, + graph="parent", + goto="next_node", + resume="resume-token", + ) + + result = middleware.wrap_tool_call(request, handler) + + assert isinstance(result, Command) + assert result.update == { + "existing_key": "existing_value", + "sandbox": {"sandbox_id": "sbx-merge"}, + } + # Critical: other Command fields must NOT be dropped by the merge. + assert result.graph == "parent" + assert result.goto == "next_node" + assert result.resume == "resume-token"