mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 15:11:09 +00:00
fix: prevent concurrent subagent file write conflicts in sandbox tools (#1714)
* fix: prevent concurrent subagent file write conflicts. Serialize same-path str_replace operations in sandbox tools. Guard AioSandbox write_file/update_file with the existing sandbox lock. Add regression tests for concurrent str_replace and append races. Verify with backend full tests and ruff lint checks. * fix(sandbox): Fix the concurrency issue of file operations on the same path in isolated sandboxes. Ensure that different sandbox instances use independent locks for file operations on the same virtual path to avoid concurrency conflicts. Change the lock key from a single path to a composite key of (sandbox.id, path), and add tests to verify the concurrent safety of isolated sandboxes. * feat(sandbox): Extract file operation lock logic to a standalone module and fix concurrency issues. Extract the file-operation lock logic from tools.py into a separate file_operation_lock.py module. Fix data race issues during concurrent str_replace and write_file operations.
This commit is contained in:
@@ -131,3 +131,53 @@ class TestListDirSerialization:
|
||||
result = sandbox.list_dir("/test")
|
||||
assert result == ["/a", "/b"]
|
||||
assert lock_was_held == [True], "list_dir must hold the lock during exec_command"
|
||||
|
||||
|
||||
class TestConcurrentFileWrites:
    """Regression tests ensuring concurrent file writes keep every update."""

    def test_append_should_preserve_both_parallel_writes(self, sandbox):
        """Two appends issued at the same moment must both reach the file.

        ``sandbox.read_file`` is swapped for a stub that lingers for up to
        50 ms after taking its snapshot (or until a second reader shows up),
        and the client's ``write_file`` is swapped for one that stores the
        content in a local dict.

        NOTE(review): presumably ``write_file(..., append=True)`` reads the
        current content through ``read_file`` and writes back the
        concatenation, so unserialized appends would share a stale snapshot
        and drop one payload -- confirm against the sandbox implementation.
        """
        backing_store = {"content": "seed\n"}
        readers_in_flight = 0
        counter_guard = threading.Lock()
        second_reader_arrived = threading.Event()

        def overlapping_read_file(path):
            nonlocal readers_in_flight
            with counter_guard:
                readers_in_flight += 1
                seen = backing_store["content"]
                if readers_in_flight == 2:
                    # Both writers are inside read_file at the same time.
                    second_reader_arrived.set()

            # Linger so a competing reader can grab the very same snapshot.
            second_reader_arrived.wait(0.05)

            with counter_guard:
                readers_in_flight -= 1
            return seen

        def write_back(*, file, content, **kwargs):
            # Last writer wins, exactly like overwriting a real file.
            backing_store["content"] = content
            return SimpleNamespace(data=SimpleNamespace())

        sandbox.read_file = overlapping_read_file
        sandbox._client.file.write_file = write_back

        # Release both writers together to maximise the chance of a collision.
        start_gate = threading.Barrier(2)

        def writer(payload: str):
            start_gate.wait()
            sandbox.write_file("/tmp/shared.log", payload, append=True)

        workers = [threading.Thread(target=writer, args=(chunk,)) for chunk in ("A\n", "B\n")]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # Either ordering is acceptable as long as neither payload was lost.
        assert backing_store["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
@@ -17,8 +18,10 @@ from deerflow.sandbox.tools import (
|
||||
mask_local_paths_in_output,
|
||||
replace_virtual_path,
|
||||
replace_virtual_paths_in_command,
|
||||
str_replace_tool,
|
||||
validate_local_bash_command_paths,
|
||||
validate_local_tool_path,
|
||||
write_file_tool,
|
||||
)
|
||||
|
||||
_THREAD_DATA = {
|
||||
@@ -512,3 +515,221 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
|
||||
with patch("deerflow.config.extensions_config.get_extensions_config", return_value=disabled_config):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) -> None:
    """Two ``str_replace`` calls racing on one shared file must both stick.

    Both workers target the same path of the same in-memory sandbox.  The
    sandbox's ``read_file`` lingers for up to 50 ms after taking its snapshot
    (or until a second reader arrives), so edits that are not serialized would
    start from the same stale content and one would overwrite the other.
    """

    class SharedSandbox:
        """In-memory sandbox whose reads linger so that racing edits collide."""

        def __init__(self) -> None:
            self.content = "alpha\nbeta\n"
            self._readers = 0
            self._readers_guard = threading.Lock()
            self._two_readers = threading.Event()

        def read_file(self, path: str) -> str:
            with self._readers_guard:
                self._readers += 1
                seen = self.content
                if self._readers == 2:
                    # A second reader entered while the first was still here.
                    self._two_readers.set()

            # Keep the snapshot "open" briefly to widen the race window.
            self._two_readers.wait(0.05)

            with self._readers_guard:
                self._readers -= 1
            return seen

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            # Plain overwrite; the append flag is ignored by this stub.
            self.content = content

    sandbox = SharedSandbox()
    failures: list[BaseException] = []
    runtimes = [SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) for _ in range(2)]

    # Route the tool to the in-memory sandbox and bypass filesystem set-up.
    for target, stub in (
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox),
        ("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False),
    ):
        monkeypatch.setattr(target, stub)

    def worker(runtime: SimpleNamespace, old_str: str, new_str: str) -> None:
        # Exceptions raised in a thread are not propagated to the main thread,
        # so collect them for the assertion at the end of the test.
        try:
            outcome = str_replace_tool.func(
                runtime=runtime,
                # description: "concurrently replace within the same file"
                description="并发替换同一文件",
                path="/mnt/user-data/workspace/shared.txt",
                old_str=old_str,
                new_str=new_str,
            )
            assert outcome == "OK"
        except BaseException as exc:  # pragma: no cover - failure is asserted below
            failures.append(exc)

    edits = (("alpha", "ALPHA"), ("beta", "BETA"))
    threads = [
        threading.Thread(target=worker, args=(runtime, old, new))
        for runtime, (old, new) in zip(runtimes, edits)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    assert failures == []
    assert "ALPHA" in sandbox.content
    assert "BETA" in sandbox.content
|
||||
|
||||
|
||||
def test_str_replace_parallel_updates_in_isolated_sandboxes_should_not_share_path_lock(monkeypatch) -> None:
    """``str_replace`` on the same path in different sandboxes must not serialize.

    Two sandboxes with distinct ``id`` values each hold their own copy of the
    file.  Both workers edit the same virtual path, one per sandbox.  Each
    sandbox's ``read_file`` blocks until the other sandbox's reader has also
    entered ``read_file`` (or a timeout expires), and the test asserts that
    this overlap really happened and that each sandbox only received its own
    edit.

    The rendezvous timeout is deliberately generous: on the passing path
    ``Event.wait`` returns as soon as the second reader arrives, so the long
    timeout costs nothing, while a short one makes the overlap assertion fail
    spuriously whenever the second thread is slow to get scheduled.
    """
    # Upper bound for the reader rendezvous; only ever consumed in full when
    # the reads fail to overlap, i.e. when the test is about to fail anyway.
    overlap_timeout_s = 1.0

    class ReadOverlapTracker:
        """Counts readers currently inside ``read_file`` across all sandboxes."""

        def __init__(self) -> None:
            self.active_reads = 0
            self.state_lock = threading.Lock()
            self.overlap_detected = threading.Event()

    class IsolatedSandbox:
        """In-memory sandbox with its own content but a shared overlap tracker."""

        def __init__(self, sandbox_id: str, tracker: ReadOverlapTracker) -> None:
            self.id = sandbox_id
            self.content = "alpha\nbeta\n"
            self._tracker = tracker

        def read_file(self, path: str) -> str:
            tracker = self._tracker
            with tracker.state_lock:
                tracker.active_reads += 1
                snapshot = self.content
                if tracker.active_reads == 2:
                    # Both sandboxes are being read at the same time.
                    tracker.overlap_detected.set()

            # Wait for the other sandbox's reader; returns immediately once set.
            tracker.overlap_detected.wait(overlap_timeout_s)

            with tracker.state_lock:
                tracker.active_reads -= 1
            return snapshot

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            # Plain overwrite; the append flag is ignored by this stub.
            self.content = content

    tracker = ReadOverlapTracker()
    sandboxes = {
        "sandbox-a": IsolatedSandbox("sandbox-a", tracker),
        "sandbox-b": IsolatedSandbox("sandbox-b", tracker),
    }
    runtimes = [
        SimpleNamespace(state={}, context={"thread_id": "thread-1", "sandbox_key": "sandbox-a"}, config={}),
        SimpleNamespace(state={}, context={"thread_id": "thread-2", "sandbox_key": "sandbox-b"}, config={}),
    ]
    failures: list[BaseException] = []

    # Each runtime resolves to its own sandbox; filesystem set-up is bypassed.
    monkeypatch.setattr(
        "deerflow.sandbox.tools.ensure_sandbox_initialized",
        lambda runtime: sandboxes[runtime.context["sandbox_key"]],
    )
    monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None)
    monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False)

    def worker(runtime: SimpleNamespace, old_str: str, new_str: str) -> None:
        # Exceptions raised in a thread are not propagated to the main thread,
        # so collect them for the assertion at the end of the test.
        try:
            result = str_replace_tool.func(
                runtime=runtime,
                # description: "concurrently replace the same path in isolated sandboxes"
                description="隔离 sandbox 并发替换同一路径",
                path="/mnt/user-data/workspace/shared.txt",
                old_str=old_str,
                new_str=new_str,
            )
            assert result == "OK"
        except BaseException as exc:  # pragma: no cover - failure is asserted below
            failures.append(exc)

    threads = [
        threading.Thread(target=worker, args=(runtimes[0], "alpha", "ALPHA")),
        threading.Thread(target=worker, args=(runtimes[1], "beta", "BETA")),
    ]

    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    assert failures == []
    # Each sandbox must contain only the edit that was addressed to it.
    assert sandboxes["sandbox-a"].content == "ALPHA\nbeta\n"
    assert sandboxes["sandbox-b"].content == "alpha\nBETA\n"
    # The reads overlapped, so the two sandboxes did not wait on a common lock.
    assert tracker.overlap_detected.is_set()
|
||||
|
||||
|
||||
def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkeypatch) -> None:
    """A ``str_replace`` racing with an append on one path must keep both changes.

    The replace worker takes a snapshot through ``read_file``, which then
    lingers for up to 50 ms waiting for an append to finish.  The append
    worker waits for up to 50 ms for that snapshot to be taken before
    appending through ``write_file_tool``.  The final content must contain
    both the replacement and the appended tail.
    """

    class SharedSandbox:
        """In-memory sandbox that signals when a snapshot was read and when an append landed."""

        def __init__(self) -> None:
            self.id = "sandbox-1"
            self.content = "alpha\n"
            self.content_guard = threading.Lock()
            self.snapshot_taken = threading.Event()
            self.append_done = threading.Event()

        def read_file(self, path: str) -> str:
            with self.content_guard:
                seen = self.content
            # Tell the append worker a (possibly soon stale) snapshot now exists...
            self.snapshot_taken.set()
            # ...and give an unserialized append the chance to slip in.
            self.append_done.wait(0.05)
            return seen

        def write_file(self, path: str, content: str, append: bool = False) -> None:
            with self.content_guard:
                if not append:
                    self.content = content
                    return
                self.content += content
                self.append_done.set()

    sandbox = SharedSandbox()
    failures: list[BaseException] = []
    runtimes = [SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) for _ in range(2)]

    # Route both tools to the in-memory sandbox and bypass filesystem set-up.
    for target, stub in (
        ("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox),
        ("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None),
        ("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False),
    ):
        monkeypatch.setattr(target, stub)

    def replace_worker() -> None:
        # Thread exceptions are collected here and asserted on at the end.
        try:
            outcome = str_replace_tool.func(
                runtime=runtimes[0],
                # description: "replace the old content"
                description="替换旧内容",
                path="/mnt/user-data/workspace/shared.txt",
                old_str="alpha",
                new_str="ALPHA",
            )
            assert outcome == "OK"
        except BaseException as exc:  # pragma: no cover - failure is asserted below
            failures.append(exc)

    def append_worker() -> None:
        # Thread exceptions are collected here and asserted on at the end.
        try:
            # Hold back until the replace worker has (most likely) read the file.
            sandbox.snapshot_taken.wait(0.05)
            outcome = write_file_tool.func(
                runtime=runtimes[1],
                # description: "append new content"
                description="追加新内容",
                path="/mnt/user-data/workspace/shared.txt",
                content="tail\n",
                append=True,
            )
            assert outcome == "OK"
        except BaseException as exc:  # pragma: no cover - failure is asserted below
            failures.append(exc)

    # The replace worker is started first so that it is the one taking the snapshot.
    workers = [threading.Thread(target=replace_worker), threading.Thread(target=append_worker)]
    for thread in workers:
        thread.start()
    for thread in workers:
        thread.join()

    assert failures == []
    assert sandbox.content == "ALPHA\ntail\n"
|
||||
|
||||
Reference in New Issue
Block a user