mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
fix(sandbox): serialize concurrent exec_command calls in AioSandbox (#1435)
* fix(sandbox): serialize concurrent exec_command calls in AioSandbox The AIO sandbox container maintains a single persistent shell session that corrupts when multiple exec_command requests arrive concurrently (e.g. when ToolNode issues parallel tool_calls). The corrupted session returns 'ErrorObservation' strings as output, cascading into subsequent commands. Add a threading.Lock to AioSandbox to serialize shell commands. As a secondary defense, detect ErrorObservation in output and retry with a fresh session ID. Fixes #1433 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(sandbox): address Copilot review findings - Fix shell injection in list_dir: use shlex.quote(path) to escape user-provided paths in the find command - Narrow ErrorObservation retry condition from broad substring match to the specific corruption signature to prevent false retries - Improve test_lock_prevents_concurrent_execution: use threading.Barrier to ensure all workers contend for the lock simultaneously - Improve test_list_dir_uses_lock: assert lock.locked() is True during exec_command to verify lock acquisition * style: auto-format with ruff --------- Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
"""Tests for AioSandbox concurrent command serialization (#1433)."""
|
||||
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sandbox():
|
||||
"""Create an AioSandbox with a mocked client."""
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
|
||||
|
||||
sb = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
return sb
|
||||
|
||||
|
||||
class TestExecuteCommandSerialization:
|
||||
"""Verify that concurrent exec_command calls are serialized."""
|
||||
|
||||
def test_lock_prevents_concurrent_execution(self, sandbox):
|
||||
"""Concurrent threads should not overlap inside execute_command."""
|
||||
call_log = []
|
||||
barrier = threading.Barrier(3)
|
||||
|
||||
def slow_exec(command, **kwargs):
|
||||
call_log.append(("enter", command))
|
||||
import time
|
||||
|
||||
time.sleep(0.05)
|
||||
call_log.append(("exit", command))
|
||||
return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))
|
||||
|
||||
sandbox._client.shell.exec_command = slow_exec
|
||||
|
||||
def worker(cmd):
|
||||
barrier.wait() # ensure all threads contend for the lock simultaneously
|
||||
sandbox.execute_command(cmd)
|
||||
|
||||
threads = []
|
||||
for i in range(3):
|
||||
t = threading.Thread(target=worker, args=(f"cmd-{i}",))
|
||||
threads.append(t)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Verify serialization: each "enter" should be followed by its own
|
||||
# "exit" before the next "enter" (no interleaving).
|
||||
enters = [i for i, (action, _) in enumerate(call_log) if action == "enter"]
|
||||
exits = [i for i, (action, _) in enumerate(call_log) if action == "exit"]
|
||||
assert len(enters) == 3
|
||||
assert len(exits) == 3
|
||||
for e_idx, x_idx in zip(enters, exits):
|
||||
assert x_idx == e_idx + 1, f"Interleaved execution detected: {call_log}"
|
||||
|
||||
|
||||
class TestErrorObservationRetry:
|
||||
"""Verify ErrorObservation detection and fresh-session retry."""
|
||||
|
||||
def test_retry_on_error_observation(self, sandbox):
|
||||
"""When output contains ErrorObservation, retry with a fresh session."""
|
||||
call_count = 0
|
||||
|
||||
def mock_exec(command, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
|
||||
return SimpleNamespace(data=SimpleNamespace(output="success"))
|
||||
|
||||
sandbox._client.shell.exec_command = mock_exec
|
||||
|
||||
result = sandbox.execute_command("echo hello")
|
||||
assert result == "success"
|
||||
assert call_count == 2
|
||||
|
||||
def test_retry_passes_fresh_session_id(self, sandbox):
|
||||
"""The retry call should include a new session id kwarg."""
|
||||
calls = []
|
||||
|
||||
def mock_exec(command, **kwargs):
|
||||
calls.append(kwargs)
|
||||
if len(calls) == 1:
|
||||
return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
|
||||
return SimpleNamespace(data=SimpleNamespace(output="ok"))
|
||||
|
||||
sandbox._client.shell.exec_command = mock_exec
|
||||
|
||||
sandbox.execute_command("test")
|
||||
assert len(calls) == 2
|
||||
assert "id" not in calls[0]
|
||||
assert "id" in calls[1]
|
||||
assert len(calls[1]["id"]) == 36 # UUID format
|
||||
|
||||
def test_no_retry_on_clean_output(self, sandbox):
|
||||
"""Normal output should not trigger a retry."""
|
||||
call_count = 0
|
||||
|
||||
def mock_exec(command, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return SimpleNamespace(data=SimpleNamespace(output="all good"))
|
||||
|
||||
sandbox._client.shell.exec_command = mock_exec
|
||||
|
||||
result = sandbox.execute_command("echo hello")
|
||||
assert result == "all good"
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
class TestListDirSerialization:
|
||||
"""Verify that list_dir also acquires the lock."""
|
||||
|
||||
def test_list_dir_uses_lock(self, sandbox):
|
||||
"""list_dir should hold the lock during execution."""
|
||||
lock_was_held = []
|
||||
|
||||
original_exec = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b")))
|
||||
|
||||
def tracking_exec(command, **kwargs):
|
||||
lock_was_held.append(sandbox._lock.locked())
|
||||
return original_exec(command, **kwargs)
|
||||
|
||||
sandbox._client.shell.exec_command = tracking_exec
|
||||
|
||||
result = sandbox.list_dir("/test")
|
||||
assert result == ["/a", "/b"]
|
||||
assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
|
||||
Reference in New Issue
Block a user