"""Tests for AioSandbox concurrent command serialization (#1433).""" import threading from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @pytest.fixture() def sandbox(): """Create an AioSandbox with a mocked client.""" with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"): from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox sb = AioSandbox(id="test-sandbox", base_url="http://localhost:8080") return sb class TestExecuteCommandSerialization: """Verify that concurrent exec_command calls are serialized.""" def test_lock_prevents_concurrent_execution(self, sandbox): """Concurrent threads should not overlap inside execute_command.""" call_log = [] barrier = threading.Barrier(3) def slow_exec(command, **kwargs): call_log.append(("enter", command)) import time time.sleep(0.05) call_log.append(("exit", command)) return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}")) sandbox._client.shell.exec_command = slow_exec def worker(cmd): barrier.wait() # ensure all threads contend for the lock simultaneously sandbox.execute_command(cmd) threads = [] for i in range(3): t = threading.Thread(target=worker, args=(f"cmd-{i}",)) threads.append(t) for t in threads: t.start() for t in threads: t.join() # Verify serialization: each "enter" should be followed by its own # "exit" before the next "enter" (no interleaving). enters = [i for i, (action, _) in enumerate(call_log) if action == "enter"] exits = [i for i, (action, _) in enumerate(call_log) if action == "exit"] assert len(enters) == 3 assert len(exits) == 3 for e_idx, x_idx in zip(enters, exits): assert x_idx == e_idx + 1, f"Interleaved execution detected: {call_log}" class TestErrorObservationRetry: """Verify ErrorObservation detection and fresh-session retry.""" def test_retry_on_error_observation(self, sandbox): """When output contains ErrorObservation, retry with a fresh session.""" call_count = 0 def mock_exec(command, **kwargs): nonlocal call_count call_count += 1 if call_count == 1: return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'")) return SimpleNamespace(data=SimpleNamespace(output="success")) sandbox._client.shell.exec_command = mock_exec result = sandbox.execute_command("echo hello") assert result == "success" assert call_count == 2 def test_retry_passes_fresh_session_id(self, sandbox): """The retry call should include a new session id kwarg.""" calls = [] def mock_exec(command, **kwargs): calls.append(kwargs) if len(calls) == 1: return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'")) return SimpleNamespace(data=SimpleNamespace(output="ok")) sandbox._client.shell.exec_command = mock_exec sandbox.execute_command("test") assert len(calls) == 2 assert "id" not in calls[0] assert "id" in calls[1] assert len(calls[1]["id"]) == 36 # UUID format def test_no_retry_on_clean_output(self, sandbox): """Normal output should not trigger a retry.""" call_count = 0 def mock_exec(command, **kwargs): nonlocal call_count call_count += 1 return SimpleNamespace(data=SimpleNamespace(output="all good")) sandbox._client.shell.exec_command = mock_exec result = sandbox.execute_command("echo hello") assert result == "all good" assert call_count == 1 class TestListDirSerialization: """Verify that list_dir also acquires the lock.""" def test_list_dir_uses_lock(self, sandbox): """list_dir should hold the lock during execution.""" lock_was_held = [] original_exec = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))) def tracking_exec(command, **kwargs): lock_was_held.append(sandbox._lock.locked()) return original_exec(command, **kwargs) sandbox._client.shell.exec_command = tracking_exec result = sandbox.list_dir("/test") assert result == ["/a", "/b"] assert lock_was_held == [True], "list_dir must hold the lock during exec_command" class TestNoChangeTimeout: """Verify that no_change_timeout is forwarded to every exec_command call.""" def test_execute_command_passes_no_change_timeout(self, sandbox): """execute_command should pass no_change_timeout to exec_command.""" calls = [] def mock_exec(command, **kwargs): calls.append(kwargs) return SimpleNamespace(data=SimpleNamespace(output="ok")) sandbox._client.shell.exec_command = mock_exec sandbox.execute_command("echo hello") assert len(calls) == 1 assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT def test_retry_passes_no_change_timeout(self, sandbox): """The ErrorObservation retry path should also pass no_change_timeout.""" calls = [] def mock_exec(command, **kwargs): calls.append(kwargs) if len(calls) == 1: return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'")) return SimpleNamespace(data=SimpleNamespace(output="ok")) sandbox._client.shell.exec_command = mock_exec sandbox.execute_command("echo hello") assert len(calls) == 2 assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT assert calls[1].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT def test_list_dir_passes_no_change_timeout(self, sandbox): """list_dir should pass no_change_timeout to exec_command.""" calls = [] def mock_exec(command, **kwargs): calls.append(kwargs) return SimpleNamespace(data=SimpleNamespace(output="/a\n/b")) sandbox._client.shell.exec_command = mock_exec sandbox.list_dir("/test") assert len(calls) == 1 assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT class TestConcurrentFileWrites: """Verify file write paths do not lose concurrent updates.""" def test_append_should_preserve_both_parallel_writes(self, sandbox): storage = {"content": "seed\n"} active_reads = 0 state_lock = threading.Lock() overlap_detected = threading.Event() def overlapping_read_file(path): nonlocal active_reads with state_lock: active_reads += 1 snapshot = storage["content"] if active_reads == 2: overlap_detected.set() overlap_detected.wait(0.05) with state_lock: active_reads -= 1 return snapshot def write_back(*, file, content, **kwargs): storage["content"] = content return SimpleNamespace(data=SimpleNamespace()) sandbox.read_file = overlapping_read_file sandbox._client.file.write_file = write_back barrier = threading.Barrier(2) def writer(payload: str): barrier.wait() sandbox.write_file("/tmp/shared.log", payload, append=True) threads = [ threading.Thread(target=writer, args=("A\n",)), threading.Thread(target=writer, args=("B\n",)), ] for thread in threads: thread.start() for thread in threads: thread.join() assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"} class TestDownloadFile: """Tests for AioSandbox.download_file.""" def test_returns_concatenated_bytes(self, sandbox): """download_file should join chunks from the client iterator into bytes.""" sandbox._client.file.download_file = MagicMock(return_value=[b"hel", b"lo"]) result = sandbox.download_file("/mnt/user-data/outputs/file.bin") assert result == b"hello" sandbox._client.file.download_file.assert_called_once_with(path="/mnt/user-data/outputs/file.bin") def test_returns_empty_bytes_for_empty_file(self, sandbox): """download_file should return b'' when the iterator yields nothing.""" sandbox._client.file.download_file = MagicMock(return_value=iter([])) result = sandbox.download_file("/mnt/user-data/outputs/empty.bin") assert result == b"" def test_uses_lock_during_download(self, sandbox): """download_file should hold the lock while calling the client.""" lock_was_held = [] def tracking_download(path): lock_was_held.append(sandbox._lock.locked()) return iter([b"data"]) sandbox._client.file.download_file = tracking_download sandbox.download_file("/mnt/user-data/outputs/file.bin") assert lock_was_held == [True], "download_file must hold the lock during client call" def test_raises_oserror_on_client_error(self, sandbox): """download_file should wrap client exceptions as OSError.""" sandbox._client.file.download_file = MagicMock(side_effect=RuntimeError("network error")) with pytest.raises(OSError, match="network error"): sandbox.download_file("/mnt/user-data/outputs/file.bin") def test_preserves_oserror_from_client(self, sandbox): """OSError raised by the client should propagate without re-wrapping.""" sandbox._client.file.download_file = MagicMock(side_effect=OSError("disk error")) with pytest.raises(OSError, match="disk error"): sandbox.download_file("/mnt/user-data/outputs/file.bin") def test_rejects_path_outside_virtual_prefix_and_logs_error(self, sandbox, caplog): """download_file must reject downloads outside /mnt/user-data and log the reason.""" sandbox._client.file.download_file = MagicMock() with caplog.at_level("ERROR"): with pytest.raises(PermissionError, match="must be under"): sandbox.download_file("/etc/passwd") assert "outside allowed directory" in caplog.text sandbox._client.file.download_file.assert_not_called() @pytest.mark.parametrize( "path", [ "/mnt/workspace/../../etc/passwd", "../secret", "/a/b/../../../etc/shadow", ], ) def test_rejects_path_traversal(self, sandbox, path): """download_file must reject paths containing '..' before calling the client.""" sandbox._client.file.download_file = MagicMock() with pytest.raises(PermissionError, match="path traversal"): sandbox.download_file(path) sandbox._client.file.download_file.assert_not_called() def test_single_chunk(self, sandbox): """download_file should work correctly with a single-chunk response.""" sandbox._client.file.download_file = MagicMock(return_value=[b"single-chunk"]) result = sandbox.download_file("/mnt/user-data/outputs/single.bin") assert result == b"single-chunk"