Files
deer-flow/backend/tests/test_aio_sandbox.py
T
stphtt f212da9f89 fix(sandbox): create shell session before retrying on a fresh id (#3577)
* fix(sandbox): create shell session before retrying on a fresh id

The AIO sandbox recovery path generated a UUID and passed it straight to
exec_command(id=...). The sandbox image only auto-creates a session when
exec_command is called with *no* id; an exec carrying an unknown id returns
HTTP 404 "Session not found". So every ErrorObservation recovery itself
404'd, turning a transient session lapse into an unrecoverable tool error
that looped the run up to the LangGraph recursion limit.

Explicitly create_session(id=fresh_id) before targeting that id on retry.
create_session is idempotent (returns the existing session if the id already
exists), so this is safe under the serializing lock.

Updated the regression test to assert the retry targets exactly the
created session id rather than a fabricated, uncreated one.

* fix(sandbox): release the one-shot recovery session after retry

The fresh session created on the ErrorObservation recovery path is used for
exactly one command -- the next execute_command runs with no id and returns
to the default session. Under persistent session corruption every command
would create another session that is never reused or released, accumulating
sessions on the container.

Release it best-effort with cleanup_session() in a finally, swallowing any
cleanup error so it never masks a successful retry.

Addresses review feedback on #3577.

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-06-17 10:21:27 +08:00

442 lines
17 KiB
Python

"""Tests for AioSandbox concurrent command serialization (#1433)."""
import threading
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture()
def sandbox():
    """Return an ``AioSandbox`` built while ``AioSandboxClient`` is patched out.

    The patch is active only for the duration of the import and the
    constructor call, so the sandbox is created against the mock class
    instead of a real client.
    """
    client_target = "deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"
    with patch(client_target):
        # Import happens under the patch, exactly as the construction does.
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
class TestExecuteCommandSerialization:
    """Check that overlapping ``execute_command`` calls are run one at a time."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Three threads released together must not interleave inside exec_command."""
        events = []
        start_gate = threading.Barrier(3)

        def slow_exec(command, **kwargs):
            # Record entry, linger long enough for another thread to sneak in
            # if nothing serializes the calls, then record exit.
            events.append(("enter", command))
            import time

            time.sleep(0.05)
            events.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def worker(cmd):
            # The barrier makes all workers contend for the lock at once.
            start_gate.wait()
            sandbox.execute_command(cmd)

        threads = [
            threading.Thread(target=worker, args=(f"cmd-{i}",))
            for i in range(3)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Serialized execution means every "enter" is immediately followed
        # by an "exit" in the log, with no other entry in between.
        enter_positions = [pos for pos, entry in enumerate(events) if entry[0] == "enter"]
        exit_positions = [pos for pos, entry in enumerate(events) if entry[0] == "exit"]

        assert len(enter_positions) == 3
        assert len(exit_positions) == 3
        for entered_at, exited_at in zip(enter_positions, exit_positions):
            assert exited_at == entered_at + 1, f"Interleaved execution detected: {events}"
class TestErrorObservationRetry:
    """Exercise ErrorObservation detection and the fresh-session retry path."""

    # Output returned by the first exec attempt to simulate a corrupted session.
    _CORRUPTED_OUTPUT = "'ErrorObservation' object has no attribute 'exit_code'"

    @staticmethod
    def _reply(text):
        """Wrap *text* in the ``.data.output`` shape the mocked exec calls return."""
        return SimpleNamespace(data=SimpleNamespace(output=text))

    def test_retry_on_error_observation(self, sandbox):
        """An ErrorObservation in the output should lead to one retry that succeeds."""
        attempts = []

        def mock_exec(command, **kwargs):
            attempts.append(command)
            if len(attempts) == 1:
                return self._reply(self._CORRUPTED_OUTPUT)
            return self._reply("success")

        sandbox._client.shell.exec_command = mock_exec

        assert sandbox.execute_command("echo hello") == "success"
        assert len(attempts) == 2

    def test_retry_creates_fresh_session_before_targeting_it(self, sandbox):
        """Recovery must create a session explicitly and then exec against that id.

        The sandbox image auto-creates a session only when exec_command is
        called with *no* id; an exec that carries an unknown id gets HTTP 404
        "Session not found". The retry therefore has to obtain a real,
        distinct session through create_session() and target that id, not
        invent an id and pass it straight to exec_command (the regression
        that 404'd every recovery and looped runs to the recursion limit).
        """
        exec_kwargs = []
        created_ids = []
        cleaned_ids = []

        def mock_exec(command, **kwargs):
            exec_kwargs.append(kwargs)
            if len(exec_kwargs) == 1:
                return self._reply(self._CORRUPTED_OUTPUT)
            return self._reply("ok")

        def mock_create_session(id, **kwargs):
            created_ids.append(id)
            return SimpleNamespace(data=SimpleNamespace(session_id=id))

        def mock_cleanup_session(session_id, **kwargs):
            cleaned_ids.append(session_id)

        shell = sandbox._client.shell
        shell.exec_command = mock_exec
        shell.create_session = mock_create_session
        shell.cleanup_session = mock_cleanup_session

        assert sandbox.execute_command("test") == "ok"
        assert len(exec_kwargs) == 2

        first_call, retry_call = exec_kwargs
        # The initial attempt runs on the default session (no id).
        assert "id" not in first_call
        # Exactly one fresh session was explicitly created...
        assert len(created_ids) == 1
        fresh_id = created_ids[0]
        assert len(fresh_id) == 36  # UUID format
        # ...the retry targets that created session, never an
        # uncreated/fabricated id (which would 404)...
        assert retry_call.get("id") == fresh_id
        # ...and the one-shot recovery session is released afterwards so a
        # sandbox that keeps hitting corruption doesn't accumulate sessions.
        assert cleaned_ids == [fresh_id]

    def test_cleanup_failure_does_not_mask_successful_retry(self, sandbox):
        """An error while releasing the recovery session must not lose the retry output."""

        def mock_exec(command, **kwargs):
            # Only the id-less (default session) call reports corruption.
            if "id" in kwargs:
                return self._reply("recovered")
            return self._reply(self._CORRUPTED_OUTPUT)

        def mock_create_session(id, **kwargs):
            return SimpleNamespace(data=SimpleNamespace(session_id=id))

        def mock_cleanup_session(session_id, **kwargs):
            raise RuntimeError("cleanup boom")

        shell = sandbox._client.shell
        shell.exec_command = mock_exec
        shell.create_session = mock_create_session
        shell.cleanup_session = mock_cleanup_session

        # The retry succeeded; the swallowed cleanup error must not turn
        # this into an "Error: ..." result.
        assert sandbox.execute_command("test") == "recovered"

    def test_no_retry_on_clean_output(self, sandbox):
        """Ordinary output should be returned after a single exec call."""
        attempts = []

        def mock_exec(command, **kwargs):
            attempts.append(command)
            return self._reply("all good")

        sandbox._client.shell.exec_command = mock_exec

        assert sandbox.execute_command("echo hello") == "all good"
        assert len(attempts) == 1
class TestListDirSerialization:
    """Check that ``list_dir`` takes the sandbox lock as well."""

    def test_list_dir_uses_lock(self, sandbox):
        """The lock should be held at the moment list_dir reaches exec_command."""
        lock_states = []
        listing = SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))

        def tracking_exec(command, **kwargs):
            # Sample the lock from inside the client call.
            lock_states.append(sandbox._lock.locked())
            return listing

        sandbox._client.shell.exec_command = tracking_exec

        entries = sandbox.list_dir("/test")

        assert entries == ["/a", "/b"]
        assert lock_states == [True], "list_dir must hold the lock during exec_command"
class TestNoChangeTimeout:
    """Check that ``no_change_timeout`` reaches every ``exec_command`` call."""

    def test_execute_command_passes_no_change_timeout(self, sandbox):
        """A plain execute_command forwards the default no_change_timeout."""
        seen_kwargs = []

        def mock_exec(command, **kwargs):
            seen_kwargs.append(kwargs)
            return SimpleNamespace(data=SimpleNamespace(output="ok"))

        sandbox._client.shell.exec_command = mock_exec
        sandbox.execute_command("echo hello")

        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        assert len(seen_kwargs) == 1
        assert seen_kwargs[0].get("no_change_timeout") == expected

    def test_retry_passes_no_change_timeout(self, sandbox):
        """Both the first attempt and the ErrorObservation retry forward the timeout."""
        seen_kwargs = []

        def mock_exec(command, **kwargs):
            seen_kwargs.append(kwargs)
            # Only the very first call reports a corrupted session.
            text = "'ErrorObservation' object has no attribute 'exit_code'" if len(seen_kwargs) == 1 else "ok"
            return SimpleNamespace(data=SimpleNamespace(output=text))

        sandbox._client.shell.exec_command = mock_exec
        sandbox.execute_command("echo hello")

        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        assert len(seen_kwargs) == 2
        first_call, retry_call = seen_kwargs
        assert first_call.get("no_change_timeout") == expected
        assert retry_call.get("no_change_timeout") == expected

    def test_list_dir_passes_no_change_timeout(self, sandbox):
        """list_dir forwards the default no_change_timeout too."""
        seen_kwargs = []

        def mock_exec(command, **kwargs):
            seen_kwargs.append(kwargs)
            return SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))

        sandbox._client.shell.exec_command = mock_exec
        sandbox.list_dir("/test")

        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        assert len(seen_kwargs) == 1
        assert seen_kwargs[0].get("no_change_timeout") == expected
class TestConcurrentFileWrites:
    """Check that the file write path does not drop concurrent updates."""

    def test_append_should_preserve_both_parallel_writes(self, sandbox):
        """Two simultaneous appends must both end up in the stored content."""
        storage = {"content": "seed\n"}
        readers_inside = 0
        bookkeeping = threading.Lock()
        overlap_detected = threading.Event()

        def overlapping_read_file(path):
            nonlocal readers_inside
            with bookkeeping:
                readers_inside += 1
                snapshot = storage["content"]
                if readers_inside == 2:
                    overlap_detected.set()

            # Give a second reader a short window to arrive; if reads can
            # overlap, both readers leave with the same stale snapshot.
            overlap_detected.wait(0.05)

            with bookkeeping:
                readers_inside -= 1

            return snapshot

        def write_back(*, file, content, **kwargs):
            storage["content"] = content
            return SimpleNamespace(data=SimpleNamespace())

        sandbox.read_file = overlapping_read_file
        sandbox._client.file.write_file = write_back

        start_gate = threading.Barrier(2)

        def writer(payload: str):
            start_gate.wait()
            sandbox.write_file("/tmp/shared.log", payload, append=True)

        writers = [threading.Thread(target=writer, args=(chunk,)) for chunk in ("A\n", "B\n")]
        for thread in writers:
            thread.start()
        for thread in writers:
            thread.join()

        # Either ordering is acceptable, but neither append may be lost.
        assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
class TestDownloadFile:
    """Cover ``AioSandbox.download_file``."""

    def test_returns_concatenated_bytes(self, sandbox):
        """Chunks yielded by the client are joined into a single bytes value."""
        target = "/mnt/user-data/outputs/file.bin"
        downloader = MagicMock(return_value=[b"hel", b"lo"])
        sandbox._client.file.download_file = downloader

        assert sandbox.download_file(target) == b"hello"
        downloader.assert_called_once_with(path=target)

    def test_returns_empty_bytes_for_empty_file(self, sandbox):
        """An iterator that yields nothing results in ``b''``."""
        sandbox._client.file.download_file = MagicMock(return_value=iter([]))

        assert sandbox.download_file("/mnt/user-data/outputs/empty.bin") == b""

    def test_uses_lock_during_download(self, sandbox):
        """The sandbox lock should be held while the client download runs."""
        lock_states = []

        def tracking_download(path):
            lock_states.append(sandbox._lock.locked())
            return iter([b"data"])

        sandbox._client.file.download_file = tracking_download
        sandbox.download_file("/mnt/user-data/outputs/file.bin")

        assert lock_states == [True], "download_file must hold the lock during client call"

    def test_raises_oserror_on_client_error(self, sandbox):
        """A non-OSError client failure surfaces as OSError carrying the message."""
        sandbox._client.file.download_file = MagicMock(side_effect=RuntimeError("network error"))

        with pytest.raises(OSError, match="network error"):
            sandbox.download_file("/mnt/user-data/outputs/file.bin")

    def test_preserves_oserror_from_client(self, sandbox):
        """An OSError raised by the client still reaches the caller as OSError."""
        sandbox._client.file.download_file = MagicMock(side_effect=OSError("disk error"))

        with pytest.raises(OSError, match="disk error"):
            sandbox.download_file("/mnt/user-data/outputs/file.bin")

    def test_rejects_path_outside_virtual_prefix_and_logs_error(self, sandbox, caplog):
        """Paths outside /mnt/user-data are refused, logged, and never sent to the client."""
        downloader = MagicMock()
        sandbox._client.file.download_file = downloader

        with caplog.at_level("ERROR"), pytest.raises(PermissionError, match="must be under"):
            sandbox.download_file("/etc/passwd")

        assert "outside allowed directory" in caplog.text
        downloader.assert_not_called()

    @pytest.mark.parametrize(
        "path",
        [
            "/mnt/workspace/../../etc/passwd",
            "../secret",
            "/a/b/../../../etc/shadow",
        ],
    )
    def test_rejects_path_traversal(self, sandbox, path):
        """Any path containing '..' is refused before the client is called."""
        downloader = MagicMock()
        sandbox._client.file.download_file = downloader

        with pytest.raises(PermissionError, match="path traversal"):
            sandbox.download_file(path)

        downloader.assert_not_called()

    def test_single_chunk(self, sandbox):
        """A response made of one chunk is returned unchanged."""
        sandbox._client.file.download_file = MagicMock(return_value=[b"single-chunk"])

        assert sandbox.download_file("/mnt/user-data/outputs/single.bin") == b"single-chunk"
class TestClose:
    """Check that ``AioSandbox.close()`` tears down the host-side HTTP client (#2872)."""

    @staticmethod
    def _wire_nested_httpx(sandbox):
        """Attach a two-level client chain to *sandbox* and return the innermost mock.

        Mirrors the Fern layout used by these tests:
            Sandbox._client_wrapper.httpx_client -> Fern HttpClient (no close())
                .httpx_client -> httpx.Client (the real owner)
        """
        real_httpx = MagicMock(spec=["close"])
        fern_layer = SimpleNamespace(httpx_client=real_httpx)  # no close() on this layer
        sandbox._client._client_wrapper = SimpleNamespace(httpx_client=fern_layer)
        return real_httpx

    def test_close_calls_real_nested_httpx_client(self, sandbox):
        """close() has to reach the real httpx.Client at the bottom of the chain.

        The intermediate layer deliberately exposes NO close(), so a naive
        one-level lookup (the original bug) would silently close nothing.
        """
        real_httpx = self._wire_nested_httpx(sandbox)

        sandbox.close()

        real_httpx.close.assert_called_once_with()

    def test_close_clears_client_reference(self, sandbox):
        """After close() the client reference is dropped and the closed flag is set."""
        self._wire_nested_httpx(sandbox)

        sandbox.close()

        assert sandbox._client is None
        assert sandbox._closed is True

    def test_close_is_idempotent(self, sandbox):
        """Repeated close() calls close the underlying client only once."""
        real_httpx = self._wire_nested_httpx(sandbox)

        for _ in range(3):
            sandbox.close()

        assert real_httpx.close.call_count == 1

    def test_close_swallows_exceptions(self, sandbox, caplog):
        """close() is best-effort: a client error is logged, not raised."""
        real_httpx = self._wire_nested_httpx(sandbox)
        real_httpx.close.side_effect = RuntimeError("teardown boom")

        with caplog.at_level("WARNING"):
            sandbox.close()

        assert "Error closing AioSandbox client" in caplog.text

    def test_close_falls_back_to_client_close(self, sandbox):
        """With no nested httpx.Client reachable, the client's own close() is used."""
        # Swap the mocked client for a stub that exposes only a top-level close().
        stub_client = MagicMock(spec=["close"])
        sandbox._client = stub_client

        sandbox.close()

        stub_client.close.assert_called_once_with()

    def test_close_when_no_close_attr_does_not_raise(self, sandbox):
        """A client with no close attribute at all must not make close() fail."""
        sandbox._client = SimpleNamespace()  # no close, no _client_wrapper

        sandbox.close()  # must not raise

        assert sandbox._client is None