mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
refactor(tests): reorganize tests into unittest/ and e2e/ directories
- Move all unit tests from tests/ to tests/unittest/ - Add tests/e2e/ directory for end-to-end tests - Update conftest.py for new test structure - Add new tests for auth dependencies, policies, route injection - Add new tests for run callbacks, create store, execution artifacts - Remove obsolete tests for deleted persistence layer Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,183 @@
|
||||
"""Tests for AioSandbox concurrent command serialization (#1433)."""
|
||||
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def sandbox():
|
||||
"""Create an AioSandbox with a mocked client."""
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
|
||||
|
||||
sb = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
return sb
|
||||
|
||||
|
||||
class TestExecuteCommandSerialization:
|
||||
"""Verify that concurrent exec_command calls are serialized."""
|
||||
|
||||
def test_lock_prevents_concurrent_execution(self, sandbox):
|
||||
"""Concurrent threads should not overlap inside execute_command."""
|
||||
call_log = []
|
||||
barrier = threading.Barrier(3)
|
||||
|
||||
def slow_exec(command, **kwargs):
|
||||
call_log.append(("enter", command))
|
||||
import time
|
||||
|
||||
time.sleep(0.05)
|
||||
call_log.append(("exit", command))
|
||||
return SimpleNamespace(data=SimpleNamespace(output=f"ok: {command}"))
|
||||
|
||||
sandbox._client.shell.exec_command = slow_exec
|
||||
|
||||
def worker(cmd):
|
||||
barrier.wait() # ensure all threads contend for the lock simultaneously
|
||||
sandbox.execute_command(cmd)
|
||||
|
||||
threads = []
|
||||
for i in range(3):
|
||||
t = threading.Thread(target=worker, args=(f"cmd-{i}",))
|
||||
threads.append(t)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Verify serialization: each "enter" should be followed by its own
|
||||
# "exit" before the next "enter" (no interleaving).
|
||||
enters = [i for i, (action, _) in enumerate(call_log) if action == "enter"]
|
||||
exits = [i for i, (action, _) in enumerate(call_log) if action == "exit"]
|
||||
assert len(enters) == 3
|
||||
assert len(exits) == 3
|
||||
for e_idx, x_idx in zip(enters, exits):
|
||||
assert x_idx == e_idx + 1, f"Interleaved execution detected: {call_log}"
|
||||
|
||||
|
||||
class TestErrorObservationRetry:
|
||||
"""Verify ErrorObservation detection and fresh-session retry."""
|
||||
|
||||
def test_retry_on_error_observation(self, sandbox):
|
||||
"""When output contains ErrorObservation, retry with a fresh session."""
|
||||
call_count = 0
|
||||
|
||||
def mock_exec(command, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
|
||||
return SimpleNamespace(data=SimpleNamespace(output="success"))
|
||||
|
||||
sandbox._client.shell.exec_command = mock_exec
|
||||
|
||||
result = sandbox.execute_command("echo hello")
|
||||
assert result == "success"
|
||||
assert call_count == 2
|
||||
|
||||
def test_retry_passes_fresh_session_id(self, sandbox):
|
||||
"""The retry call should include a new session id kwarg."""
|
||||
calls = []
|
||||
|
||||
def mock_exec(command, **kwargs):
|
||||
calls.append(kwargs)
|
||||
if len(calls) == 1:
|
||||
return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
|
||||
return SimpleNamespace(data=SimpleNamespace(output="ok"))
|
||||
|
||||
sandbox._client.shell.exec_command = mock_exec
|
||||
|
||||
sandbox.execute_command("test")
|
||||
assert len(calls) == 2
|
||||
assert "id" not in calls[0]
|
||||
assert "id" in calls[1]
|
||||
assert len(calls[1]["id"]) == 36 # UUID format
|
||||
|
||||
def test_no_retry_on_clean_output(self, sandbox):
|
||||
"""Normal output should not trigger a retry."""
|
||||
call_count = 0
|
||||
|
||||
def mock_exec(command, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return SimpleNamespace(data=SimpleNamespace(output="all good"))
|
||||
|
||||
sandbox._client.shell.exec_command = mock_exec
|
||||
|
||||
result = sandbox.execute_command("echo hello")
|
||||
assert result == "all good"
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
class TestListDirSerialization:
|
||||
"""Verify that list_dir also acquires the lock."""
|
||||
|
||||
def test_list_dir_uses_lock(self, sandbox):
|
||||
"""list_dir should hold the lock during execution."""
|
||||
lock_was_held = []
|
||||
|
||||
original_exec = MagicMock(return_value=SimpleNamespace(data=SimpleNamespace(output="/a\n/b")))
|
||||
|
||||
def tracking_exec(command, **kwargs):
|
||||
lock_was_held.append(sandbox._lock.locked())
|
||||
return original_exec(command, **kwargs)
|
||||
|
||||
sandbox._client.shell.exec_command = tracking_exec
|
||||
|
||||
result = sandbox.list_dir("/test")
|
||||
assert result == ["/a", "/b"]
|
||||
assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
|
||||
|
||||
|
||||
class TestConcurrentFileWrites:
|
||||
"""Verify file write paths do not lose concurrent updates."""
|
||||
|
||||
def test_append_should_preserve_both_parallel_writes(self, sandbox):
|
||||
storage = {"content": "seed\n"}
|
||||
active_reads = 0
|
||||
state_lock = threading.Lock()
|
||||
overlap_detected = threading.Event()
|
||||
|
||||
def overlapping_read_file(path):
|
||||
nonlocal active_reads
|
||||
with state_lock:
|
||||
active_reads += 1
|
||||
snapshot = storage["content"]
|
||||
if active_reads == 2:
|
||||
overlap_detected.set()
|
||||
|
||||
overlap_detected.wait(0.05)
|
||||
|
||||
with state_lock:
|
||||
active_reads -= 1
|
||||
|
||||
return snapshot
|
||||
|
||||
def write_back(*, file, content, **kwargs):
|
||||
storage["content"] = content
|
||||
return SimpleNamespace(data=SimpleNamespace())
|
||||
|
||||
sandbox.read_file = overlapping_read_file
|
||||
sandbox._client.file.write_file = write_back
|
||||
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def writer(payload: str):
|
||||
barrier.wait()
|
||||
sandbox.write_file("/tmp/shared.log", payload, append=True)
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=writer, args=("A\n",)),
|
||||
threading.Thread(target=writer, args=("B\n",)),
|
||||
]
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
|
||||
Reference in New Issue
Block a user