mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(backend): fix uploads for mounted sandbox providers (#2199)
* fix uploads for mounted sandbox providers
* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from fastapi import APIRouter, File, HTTPException, UploadFile
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
from deerflow.config.paths import get_paths
|
||||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
|
||||||
from deerflow.uploads.manager import (
|
from deerflow.uploads.manager import (
|
||||||
PathTraversalError,
|
PathTraversalError,
|
||||||
delete_file_safe,
|
delete_file_safe,
|
||||||
@@ -53,6 +53,10 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
|||||||
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
os.chmod(file_path, writable_mode, **chmod_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
|
||||||
|
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=UploadResponse)
|
@router.post("", response_model=UploadResponse)
|
||||||
async def upload_files(
|
async def upload_files(
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -70,8 +74,11 @@ async def upload_files(
|
|||||||
uploaded_files = []
|
uploaded_files = []
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
sandbox = None
|
||||||
|
if sync_to_sandbox:
|
||||||
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
if not file.filename:
|
if not file.filename:
|
||||||
@@ -90,7 +97,7 @@ async def upload_files(
|
|||||||
|
|
||||||
virtual_path = upload_virtual_path(safe_filename)
|
virtual_path = upload_virtual_path(safe_filename)
|
||||||
|
|
||||||
if sandbox_id != "local":
|
if sync_to_sandbox and sandbox is not None:
|
||||||
_make_file_sandbox_writable(file_path)
|
_make_file_sandbox_writable(file_path)
|
||||||
sandbox.update_file(virtual_path, content)
|
sandbox.update_file(virtual_path, content)
|
||||||
|
|
||||||
@@ -110,7 +117,7 @@ async def upload_files(
|
|||||||
if md_path:
|
if md_path:
|
||||||
md_virtual_path = upload_virtual_path(md_path.name)
|
md_virtual_path = upload_virtual_path(md_path.name)
|
||||||
|
|
||||||
if sandbox_id != "local":
|
if sync_to_sandbox and sandbox is not None:
|
||||||
_make_file_sandbox_writable(md_path)
|
_make_file_sandbox_writable(md_path)
|
||||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
||||||
|
|
||||||
|
|||||||
@@ -119,6 +119,16 @@ class AioSandboxProvider(SandboxProvider):
|
|||||||
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
if self._config.get("idle_timeout", DEFAULT_IDLE_TIMEOUT) > 0:
|
||||||
self._start_idle_checker()
|
self._start_idle_checker()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def uses_thread_data_mounts(self) -> bool:
|
||||||
|
"""Whether thread workspace/uploads/outputs are visible via mounts.
|
||||||
|
|
||||||
|
Local container backends bind-mount the thread data directories, so files
|
||||||
|
written by the gateway are already visible when the sandbox starts.
|
||||||
|
Remote backends may require explicit file sync.
|
||||||
|
"""
|
||||||
|
return isinstance(self._backend, LocalContainerBackend)
|
||||||
|
|
||||||
# ── Factory methods ──────────────────────────────────────────────────
|
# ── Factory methods ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def _create_backend(self) -> SandboxBackend:
|
def _create_backend(self) -> SandboxBackend:
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ _singleton: LocalSandbox | None = None
|
|||||||
|
|
||||||
|
|
||||||
class LocalSandboxProvider(SandboxProvider):
|
class LocalSandboxProvider(SandboxProvider):
|
||||||
|
uses_thread_data_mounts = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize the local sandbox provider with path mappings."""
|
"""Initialize the local sandbox provider with path mappings."""
|
||||||
self._path_mappings = self._setup_path_mappings()
|
self._path_mappings = self._setup_path_mappings()
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from deerflow.sandbox.sandbox import Sandbox
|
|||||||
class SandboxProvider(ABC):
|
class SandboxProvider(ABC):
|
||||||
"""Abstract base class for sandbox providers"""
|
"""Abstract base class for sandbox providers"""
|
||||||
|
|
||||||
|
uses_thread_data_mounts: bool = False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def acquire(self, thread_id: str | None = None) -> str:
|
def acquire(self, thread_id: str | None = None) -> str:
|
||||||
"""Acquire a sandbox environment and return its ID.
|
"""Acquire a sandbox environment and return its ID.
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
|
|||||||
thread_uploads_dir.mkdir(parents=True)
|
thread_uploads_dir.mkdir(parents=True)
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
|
provider.uses_thread_data_mounts = True
|
||||||
provider.acquire.return_value = "local"
|
provider.acquire.return_value = "local"
|
||||||
sandbox = MagicMock()
|
sandbox = MagicMock()
|
||||||
provider.get.return_value = sandbox
|
provider.get.return_value = sandbox
|
||||||
@@ -34,11 +35,33 @@ def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_pat
|
|||||||
sandbox.update_file.assert_not_called()
|
sandbox.update_file.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_skips_acquire_when_thread_data_is_mounted(tmp_path):
|
||||||
|
thread_uploads_dir = tmp_path / "uploads"
|
||||||
|
thread_uploads_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.uses_thread_data_mounts = True
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
|
||||||
|
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
|
||||||
|
patch.object(uploads, "get_sandbox_provider", return_value=provider),
|
||||||
|
):
|
||||||
|
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
|
||||||
|
result = asyncio.run(uploads.upload_files("thread-mounted", files=[file]))
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert (thread_uploads_dir / "notes.txt").read_bytes() == b"hello uploads"
|
||||||
|
provider.acquire.assert_not_called()
|
||||||
|
provider.get.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
|
def test_upload_files_syncs_non_local_sandbox_and_marks_markdown_file(tmp_path):
|
||||||
thread_uploads_dir = tmp_path / "uploads"
|
thread_uploads_dir = tmp_path / "uploads"
|
||||||
thread_uploads_dir.mkdir(parents=True)
|
thread_uploads_dir.mkdir(parents=True)
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
|
provider.uses_thread_data_mounts = False
|
||||||
provider.acquire.return_value = "aio-1"
|
provider.acquire.return_value = "aio-1"
|
||||||
sandbox = MagicMock()
|
sandbox = MagicMock()
|
||||||
provider.get.return_value = sandbox
|
provider.get.return_value = sandbox
|
||||||
@@ -75,6 +98,7 @@ def test_upload_files_makes_non_local_files_sandbox_writable(tmp_path):
|
|||||||
thread_uploads_dir.mkdir(parents=True)
|
thread_uploads_dir.mkdir(parents=True)
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
|
provider.uses_thread_data_mounts = False
|
||||||
provider.acquire.return_value = "aio-1"
|
provider.acquire.return_value = "aio-1"
|
||||||
sandbox = MagicMock()
|
sandbox = MagicMock()
|
||||||
provider.get.return_value = sandbox
|
provider.get.return_value = sandbox
|
||||||
@@ -104,6 +128,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
|
|||||||
thread_uploads_dir.mkdir(parents=True)
|
thread_uploads_dir.mkdir(parents=True)
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
|
provider.uses_thread_data_mounts = True
|
||||||
provider.acquire.return_value = "local"
|
provider.acquire.return_value = "local"
|
||||||
sandbox = MagicMock()
|
sandbox = MagicMock()
|
||||||
provider.get.return_value = sandbox
|
provider.get.return_value = sandbox
|
||||||
|
|||||||
Reference in New Issue
Block a user