mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
fix(uploads): enforce streaming upload limits in gateway (#2589)
* fix: enforce gateway upload limits
* fix: acquire sandbox before upload writes
* Fix upload limit config wiring
* Sanitize upload size error filenames
* test: call upload routes unwrapped
* fix: guard upload limits endpoint

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -5,12 +5,35 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from _router_auth_helpers import call_unwrapped
|
||||
from fastapi import UploadFile
|
||||
import pytest
|
||||
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi import HTTPException, UploadFile
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.routers import uploads
|
||||
|
||||
|
||||
class ChunkedUpload:
|
||||
def __init__(self, filename: str, chunks: list[bytes]):
|
||||
self.filename = filename
|
||||
self._chunks = list(chunks)
|
||||
self.read_calls: list[int | None] = []
|
||||
|
||||
async def read(self, size: int | None = None) -> bytes:
|
||||
self.read_calls.append(size)
|
||||
if size is None:
|
||||
raise AssertionError("upload must be read with an explicit chunk size")
|
||||
if not self._chunks:
|
||||
return b""
|
||||
return self._chunks.pop(0)
|
||||
|
||||
|
||||
def _mounted_provider() -> MagicMock:
|
||||
provider = MagicMock()
|
||||
provider.uses_thread_data_mounts = True
|
||||
return provider
|
||||
|
||||
|
||||
def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path):
|
||||
thread_uploads_dir = tmp_path / "uploads"
|
||||
thread_uploads_dir.mkdir(parents=True)
|
||||
@@ -178,6 +201,173 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
|
||||
make_writable.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path):
    """The sandbox is acquired while the uploads directory is still empty.

    Also checks that the upload succeeds and the file content is pushed to
    the sandbox under ``/mnt/user-data/uploads``.
    """
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    def acquire_before_writes(thread_id: str) -> str:
        # Runs at acquire time: nothing may have been written yet.
        assert list(uploads_dir.iterdir()) == []
        return "aio-1"

    sandbox = MagicMock()
    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.get.return_value = sandbox
    provider.acquire.side_effect = acquire_before_writes

    upload = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
    dir_patch = patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir)
    provider_patch = patch.object(uploads, "get_sandbox_provider", return_value=provider)

    with dir_patch, provider_patch:
        pending = call_unwrapped(
            uploads.upload_files,
            "thread-aio",
            request=MagicMock(),
            files=[upload],
            config=SimpleNamespace(),
        )
        result = asyncio.run(pending)

    assert result.success is True
    provider.acquire.assert_called_once_with("thread-aio")
    sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/notes.txt", b"hello uploads")
|
||||
|
||||
|
||||
def test_upload_files_fails_before_writing_when_non_local_sandbox_unavailable(tmp_path):
    """A failing ``acquire`` surfaces before any file is read or written."""
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    upload = ChunkedUpload("notes.txt", [b"hello uploads"])
    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.acquire.side_effect = RuntimeError("sandbox unavailable")

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        pytest.raises(RuntimeError, match="sandbox unavailable"),
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-aio",
                request=MagicMock(),
                files=[upload],
                config=SimpleNamespace(),
            )
        )

    # Nothing on disk, the upload never read, the sandbox never fetched.
    assert list(uploads_dir.iterdir()) == []
    assert upload.read_calls == []
    provider.get.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_rejects_too_many_files_before_writing(tmp_path):
    """Two files against ``max_files=1`` give HTTP 413 with no reads and no writes."""
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    first = ChunkedUpload("one.txt", [b"one"])
    second = ChunkedUpload("two.txt", [b"two"])
    limits = uploads.UploadLimits(max_files=1, max_file_size=10, max_total_size=20)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        patch.object(uploads, "_get_upload_limits", return_value=limits),
        pytest.raises(HTTPException) as exc_info,
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-local",
                request=MagicMock(),
                files=[first, second],
                config=SimpleNamespace(),
            )
        )

    assert exc_info.value.status_code == 413
    assert list(uploads_dir.iterdir()) == []
    assert first.read_calls == []
    assert second.read_calls == []
|
||||
|
||||
|
||||
def test_upload_files_rejects_oversized_single_file_and_removes_partial_file(tmp_path):
    """A 6-byte chunk against ``max_file_size=5`` gives HTTP 413 and leaves no file.

    Also checks that exactly one read with size 8192 happened and that the
    mounted provider was never acquired.
    """
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    oversized = ChunkedUpload("big.txt", [b"123456"])
    provider = _mounted_provider()
    limits = uploads.UploadLimits(max_files=10, max_file_size=5, max_total_size=20)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_get_upload_limits", return_value=limits),
        pytest.raises(HTTPException) as exc_info,
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-local",
                request=MagicMock(),
                files=[oversized],
                config=SimpleNamespace(),
            )
        )

    assert exc_info.value.status_code == 413
    assert not (uploads_dir / "big.txt").exists()
    assert oversized.read_calls == [8192]
    provider.acquire.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_rejects_total_size_over_limit_and_cleans_request_files(tmp_path):
    """Two 3-byte files against ``max_total_size=5`` give HTTP 413; neither file remains."""
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    batch = [
        ChunkedUpload("first.txt", [b"123"]),
        ChunkedUpload("second.txt", [b"456"]),
    ]
    limits = uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        patch.object(uploads, "_get_upload_limits", return_value=limits),
        pytest.raises(HTTPException) as exc_info,
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-local",
                request=MagicMock(),
                files=batch,
                config=SimpleNamespace(),
            )
        )

    assert exc_info.value.status_code == 413
    # The file that fit on its own must be cleaned up along with the second one.
    assert not (uploads_dir / "first.txt").exists()
    assert not (uploads_dir / "second.txt").exists()
|
||||
|
||||
|
||||
def test_upload_files_does_not_sync_non_local_sandbox_when_total_size_exceeds_limit(tmp_path):
    """On a total-size 413 the sandbox is acquired and fetched once but never written to."""
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    sandbox = MagicMock()
    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.acquire.return_value = "aio-1"
    provider.get.return_value = sandbox

    batch = [
        ChunkedUpload("first.txt", [b"123"]),
        ChunkedUpload("second.txt", [b"456"]),
    ]
    limits = uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_get_upload_limits", return_value=limits),
        pytest.raises(HTTPException) as exc_info,
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-aio",
                request=MagicMock(),
                files=batch,
                config=SimpleNamespace(),
            )
        )

    assert exc_info.value.status_code == 413
    provider.acquire.assert_called_once_with("thread-aio")
    provider.get.assert_called_once_with("aio-1")
    sandbox.update_file.assert_not_called()
|
||||
|
||||
|
||||
def test_upload_files_does_not_sync_non_local_sandbox_when_conversion_fails(tmp_path):
    """A conversion error gives HTTP 500, no sandbox write, and no leftover upload file."""
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    sandbox = MagicMock()
    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.acquire.return_value = "aio-1"
    provider.get.return_value = sandbox

    # Conversion is forced on and made to fail for the uploaded PDF.
    failing_convert = AsyncMock(side_effect=RuntimeError("conversion failed"))
    upload = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_auto_convert_documents_enabled", return_value=True),
        patch.object(uploads, "convert_file_to_markdown", failing_convert),
        pytest.raises(HTTPException) as exc_info,
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-aio",
                request=MagicMock(),
                files=[upload],
                config=SimpleNamespace(),
            )
        )

    assert exc_info.value.status_code == 500
    provider.acquire.assert_called_once_with("thread-aio")
    provider.get.assert_called_once_with("aio-1")
    sandbox.update_file.assert_not_called()
    assert not (uploads_dir / "report.pdf").exists()
|
||||
|
||||
|
||||
def test_make_file_sandbox_writable_adds_write_bits_for_regular_files(tmp_path):
|
||||
file_path = tmp_path / "report.pdf"
|
||||
file_path.write_bytes(b"pdf-bytes")
|
||||
@@ -286,3 +476,65 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
|
||||
assert uploads._auto_convert_documents_enabled(true_cfg) is True
|
||||
assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
|
||||
assert uploads._auto_convert_documents_enabled(string_false_cfg) is False
|
||||
|
||||
|
||||
def test_upload_limits_endpoint_reads_uploads_config():
    """The limits endpoint returns values from ``config.uploads``.

    ``max_file_size`` is supplied as a string and is expected back as an int.
    """
    settings = {
        "max_files": 15,
        "max_file_size": "1048576",
        "max_total_size": 2097152,
    }
    cfg = MagicMock()
    cfg.uploads = settings

    pending = call_unwrapped(uploads.get_upload_limits, "thread-local", request=MagicMock(), config=cfg)
    result = asyncio.run(pending)

    assert result.max_files == 15
    assert result.max_file_size == 1048576
    assert result.max_total_size == 2097152
|
||||
|
||||
|
||||
def test_upload_limits_endpoint_requires_thread_access():
    """GET on the limits route returns 404 for an app built with ``owner_check_passes=False``."""
    cfg = MagicMock()
    cfg.uploads = {}

    app = make_authed_test_app(owner_check_passes=False)
    app.state.config = cfg
    app.include_router(uploads.router)

    limits_url = "/api/threads/thread-local/uploads/limits"
    with TestClient(app) as client:
        response = client.get(limits_url)

    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_upload_limits_accept_legacy_config_keys():
    """``max_file_count`` and ``max_single_file_size`` map onto the current limit fields."""
    legacy_settings = {
        "max_file_count": 7,
        "max_single_file_size": 123,
        "max_total_size": 456,
    }
    cfg = MagicMock()
    cfg.uploads = legacy_settings

    expected = uploads.UploadLimits(max_files=7, max_file_size=123, max_total_size=456)

    limits = uploads._get_upload_limits(cfg)

    assert limits == expected
|
||||
|
||||
|
||||
def test_upload_files_uses_configured_file_count_limit(tmp_path):
    """With ``max_files`` set to 1 in the config, a two-file upload gives HTTP 413.

    ``_get_upload_limits`` is not patched here, so the limit comes from ``cfg``.
    """
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    cfg = MagicMock()
    cfg.uploads = {"max_files": 1}

    batch = [
        ChunkedUpload("one.txt", [b"one"]),
        ChunkedUpload("two.txt", [b"two"]),
    ]

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        pytest.raises(HTTPException) as exc_info,
    ):
        asyncio.run(
            call_unwrapped(
                uploads.upload_files,
                "thread-local",
                request=MagicMock(),
                files=batch,
                config=cfg,
            )
        )

    assert exc_info.value.status_code == 413
|
||||
|
||||
Reference in New Issue
Block a user