mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 07:01:03 +00:00
e37912e2c8
* Add download interface in Sandbox * fix * fix * del invalidate test * fix * safe download * improve
321 lines
11 KiB
Python
321 lines
11 KiB
Python
"""Tests for AioSandbox concurrent command serialization (#1433)."""
|
|
|
|
import threading
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
@pytest.fixture()
def sandbox():
    """Build an AioSandbox while ``AioSandboxClient`` is replaced by a mock.

    The sandbox class is imported and instantiated inside the ``patch``
    context, so construction happens against the patched client class. Tests
    then swap individual attributes on ``sandbox._client`` for their own stubs.
    """
    client_target = "deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"
    with patch(client_target):
        from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox

        return AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
|
|
|
|
|
class TestExecuteCommandSerialization:
    """Verify that concurrent exec_command calls are serialized."""

    def test_lock_prevents_concurrent_execution(self, sandbox):
        """Concurrent threads should not overlap inside execute_command.

        Three threads are released together by a barrier and each calls
        ``sandbox.execute_command``. The stubbed ``exec_command`` records an
        "enter" and an "exit" marker around a short sleep, so overlapping calls
        would appear as two "enter" entries with no "exit" between them.
        """
        import time

        worker_count = 3
        call_log = []
        worker_errors = []
        barrier = threading.Barrier(worker_count)

        def slow_exec(command, **kwargs):
            call_log.append(("enter", command))
            # Keep the call open long enough for the other threads to collide with it.
            time.sleep(0.05)
            call_log.append(("exit", command))
            return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))

        sandbox._client.shell.exec_command = slow_exec

        def worker(cmd):
            try:
                barrier.wait()  # ensure all threads contend for the lock simultaneously
                sandbox.execute_command(cmd)
            except Exception as exc:
                # Exceptions in a thread never reach the test on their own;
                # collect them so the assertion below can report them.
                worker_errors.append(exc)

        # Daemon threads plus a bounded join keep a deadlock in the code under
        # test from hanging the whole test run.
        threads = [threading.Thread(target=worker, args=(f"cmd-{i}",), daemon=True) for i in range(worker_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=5)

        assert not any(t.is_alive() for t in threads), "worker threads did not finish (possible deadlock)"
        assert not worker_errors, f"worker threads raised: {worker_errors!r}"

        # Verify serialization: each "enter" should be followed by its own
        # "exit" before the next "enter" (no interleaving).
        enters = [i for i, (action, _) in enumerate(call_log) if action == "enter"]
        exits = [i for i, (action, _) in enumerate(call_log) if action == "exit"]
        assert len(enters) == worker_count
        assert len(exits) == worker_count
        for e_idx, x_idx in zip(enters, exits):
            assert x_idx == e_idx + 1, f"Interleaved execution detected: {call_log}"
            # The adjacent pair must belong to the same command.
            assert call_log[e_idx][1] == call_log[x_idx][1], f"Interleaved execution detected: {call_log}"
|
|
|
|
|
|
class TestErrorObservationRetry:
    """Verify ErrorObservation detection and fresh-session retry."""

    # Output returned by the stubbed client on its first call to trigger the retry path.
    _ERROR_OUTPUT = "'ErrorObservation' object has no attribute 'exit_code'"

    def test_retry_on_error_observation(self, sandbox):
        """When output contains ErrorObservation, retry with a fresh session."""
        call_count = 0

        def mock_exec(command, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return SimpleNamespace(data=SimpleNamespace(output=self._ERROR_OUTPUT))
            return SimpleNamespace(data=SimpleNamespace(output="success"))

        sandbox._client.shell.exec_command = mock_exec

        result = sandbox.execute_command("echo hello")

        # The caller sees only the output of the second (retried) call.
        assert result == "success"
        assert call_count == 2

    def test_retry_passes_fresh_session_id(self, sandbox):
        """The retry call should include a new session id kwarg."""
        import uuid

        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            if len(calls) == 1:
                return SimpleNamespace(data=SimpleNamespace(output=self._ERROR_OUTPUT))
            return SimpleNamespace(data=SimpleNamespace(output="ok"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("test")

        assert len(calls) == 2
        assert "id" not in calls[0]
        assert "id" in calls[1]
        session_id = calls[1]["id"]
        assert len(session_id) == 36  # length of a hyphenated UUID string
        # A length check alone accepts any 36-character text; parsing raises
        # ValueError unless the id is actually UUID-shaped.
        uuid.UUID(session_id)

    def test_no_retry_on_clean_output(self, sandbox):
        """Normal output should not trigger a retry."""
        call_count = 0

        def mock_exec(command, **kwargs):
            nonlocal call_count
            call_count += 1
            return SimpleNamespace(data=SimpleNamespace(output="all good"))

        sandbox._client.shell.exec_command = mock_exec

        result = sandbox.execute_command("echo hello")

        assert result == "all good"
        assert call_count == 1
|
|
|
|
|
|
class TestListDirSerialization:
    """Check that list_dir runs its exec_command call under the sandbox lock."""

    def test_list_dir_uses_lock(self, sandbox):
        """The lock must report as held at the moment exec_command is invoked."""
        observed_lock_states = []
        canned_exec = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b")))

        def tracking_exec(command, **kwargs):
            # Sample the lock state first, then hand off to the canned response.
            observed_lock_states.append(sandbox._lock.locked())
            return canned_exec(command, **kwargs)

        sandbox._client.shell.exec_command = tracking_exec

        listing = sandbox.list_dir("/test")

        assert listing == ["/a", "/b"]
        assert observed_lock_states == [True], "list_dir must hold the lock during exec_command"
|
|
|
|
|
|
class TestNoChangeTimeout:
    """Check that every exec_command call receives the no_change_timeout kwarg."""

    @staticmethod
    def _record_exec_kwargs(sandbox, outputs):
        """Stub exec_command and return the list collecting each call's kwargs.

        ``outputs`` is replayed one entry per call; the final entry is reused
        for any call made after the scripted ones.
        """
        seen_kwargs = []

        def fake_exec(command, **kwargs):
            seen_kwargs.append(kwargs)
            position = min(len(seen_kwargs), len(outputs)) - 1
            return SimpleNamespace(data=SimpleNamespace(output=outputs[position]))

        sandbox._client.shell.exec_command = fake_exec
        return seen_kwargs

    def test_execute_command_passes_no_change_timeout(self, sandbox):
        """A plain execute_command call forwards no_change_timeout."""
        seen = self._record_exec_kwargs(sandbox, ["ok"])

        sandbox.execute_command("echo hello")

        assert len(seen) == 1
        assert seen[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT

    def test_retry_passes_no_change_timeout(self, sandbox):
        """Both the first call and the ErrorObservation retry forward no_change_timeout."""
        seen = self._record_exec_kwargs(
            sandbox,
            ["'ErrorObservation' object has no attribute 'exit_code'", "ok"],
        )

        sandbox.execute_command("echo hello")

        assert len(seen) == 2
        expected = sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        assert seen[0].get("no_change_timeout") == expected
        assert seen[1].get("no_change_timeout") == expected

    def test_list_dir_passes_no_change_timeout(self, sandbox):
        """list_dir forwards no_change_timeout to exec_command as well."""
        seen = self._record_exec_kwargs(sandbox, ["/a\n/b"])

        sandbox.list_dir("/test")

        assert len(seen) == 1
        assert seen[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT
|
|
|
|
|
|
class TestConcurrentFileWrites:
    """Verify file write paths do not lose concurrent updates."""

    def test_append_should_preserve_both_parallel_writes(self, sandbox):
        """Two simultaneous appends must both end up in the stored content.

        ``sandbox.read_file`` is replaced by a stub that lingers briefly after
        taking its snapshot, and the client's ``write_file`` by one that stores
        the content it is given. Presumably ``write_file(..., append=True)``
        reads the current content through ``read_file`` before writing; if two
        such read-then-write sequences overlapped, both writers would start
        from the same snapshot and one payload would be lost.
        """
        storage = {"content": "seed\n"}
        active_reads = 0
        state_lock = threading.Lock()
        overlap_detected = threading.Event()
        worker_errors = []

        def overlapping_read_file(path):
            nonlocal active_reads
            with state_lock:
                active_reads += 1
                snapshot = storage["content"]
                if active_reads == 2:
                    overlap_detected.set()

            # Hold the read open: returns early if a second reader arrives,
            # otherwise times out so a serialized run still completes.
            overlap_detected.wait(0.05)

            with state_lock:
                active_reads -= 1

            return snapshot

        def write_back(*, file, content, **kwargs):
            storage["content"] = content
            return SimpleNamespace(data=SimpleNamespace())

        sandbox.read_file = overlapping_read_file
        sandbox._client.file.write_file = write_back

        barrier = threading.Barrier(2)

        def writer(payload: str):
            try:
                barrier.wait()
                sandbox.write_file("/tmp/shared.log", payload, append=True)
            except Exception as exc:
                # Exceptions in a thread never reach the test on their own;
                # collect them so the assertion below can report them.
                worker_errors.append(exc)

        # Daemon threads plus a bounded join keep a deadlock in the code under
        # test from hanging the whole test run.
        threads = [
            threading.Thread(target=writer, args=("A\n",), daemon=True),
            threading.Thread(target=writer, args=("B\n",), daemon=True),
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join(timeout=5)

        assert not any(thread.is_alive() for thread in threads), "writer threads did not finish (possible deadlock)"
        assert not worker_errors, f"writer threads raised: {worker_errors!r}"
        assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
|
|
|
|
|
|
class TestDownloadFile:
    """Exercise AioSandbox.download_file against a mocked file client."""

    def test_returns_concatenated_bytes(self, sandbox):
        """Chunks produced by the client are joined into a single bytes value."""
        target = "/mnt/user-data/outputs/file.bin"
        downloader = MagicMock(return_value=[b"hel", b"lo"])
        sandbox._client.file.download_file = downloader

        payload = sandbox.download_file(target)

        assert payload == b"hello"
        downloader.assert_called_once_with(path=target)

    def test_returns_empty_bytes_for_empty_file(self, sandbox):
        """An iterator that yields nothing results in b''."""
        no_chunks = iter([])
        sandbox._client.file.download_file = MagicMock(return_value=no_chunks)

        payload = sandbox.download_file("/mnt/user-data/outputs/empty.bin")

        assert payload == b""

    def test_uses_lock_during_download(self, sandbox):
        """The sandbox lock must report as held while the client is called."""
        observed_lock_states = []

        def tracking_download(path):
            # Sample the lock state at the moment the client is invoked.
            observed_lock_states.append(sandbox._lock.locked())
            return iter([b"data"])

        sandbox._client.file.download_file = tracking_download

        sandbox.download_file("/mnt/user-data/outputs/file.bin")

        assert observed_lock_states == [True], "download_file must hold the lock during client call"

    def test_raises_oserror_on_client_error(self, sandbox):
        """A non-OSError client failure surfaces as OSError carrying its message."""
        failing = MagicMock(side_effect=RuntimeError("network error"))
        sandbox._client.file.download_file = failing

        with pytest.raises(OSError, match="network error"):
            sandbox.download_file("/mnt/user-data/outputs/file.bin")

    def test_preserves_oserror_from_client(self, sandbox):
        """An OSError raised by the client reaches the caller with its message."""
        failing = MagicMock(side_effect=OSError("disk error"))
        sandbox._client.file.download_file = failing

        with pytest.raises(OSError, match="disk error"):
            sandbox.download_file("/mnt/user-data/outputs/file.bin")

    def test_rejects_path_outside_virtual_prefix_and_logs_error(self, sandbox, caplog):
        """A path outside /mnt/user-data is refused, logged, and never sent to the client."""
        downloader = MagicMock()
        sandbox._client.file.download_file = downloader

        with caplog.at_level("ERROR"), pytest.raises(PermissionError, match="must be under"):
            sandbox.download_file("/etc/passwd")

        assert "outside allowed directory" in caplog.text
        downloader.assert_not_called()

    @pytest.mark.parametrize(
        "path",
        [
            "/mnt/workspace/../../etc/passwd",
            "../secret",
            "/a/b/../../../etc/shadow",
        ],
    )
    def test_rejects_path_traversal(self, sandbox, path):
        """Paths containing '..' are refused before the client is ever called."""
        downloader = MagicMock()
        sandbox._client.file.download_file = downloader

        with pytest.raises(PermissionError, match="path traversal"):
            sandbox.download_file(path)

        downloader.assert_not_called()

    def test_single_chunk(self, sandbox):
        """A response consisting of exactly one chunk is returned unchanged."""
        sandbox._client.file.download_file = MagicMock(return_value=[b"single-chunk"])

        payload = sandbox.download_file("/mnt/user-data/outputs/single.bin")

        assert payload == b"single-chunk"
|