mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix(uploads): enforce streaming upload limits in gateway (#2589)
* fix: enforce gateway upload limits
* fix: acquire sandbox before upload writes
* Fix upload limit config wiring
* Sanitize upload size error filenames
* test: call upload routes unwrapped
* fix: guard upload limits endpoint

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -30,6 +30,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
|
||||||
|
|
||||||
|
# Number of bytes requested from an incoming upload per ``read()`` call.
UPLOAD_CHUNK_SIZE = 8 * 1024

# Fallback limits used when the uploads config does not provide a usable value.
DEFAULT_MAX_FILES = 10
DEFAULT_MAX_FILE_SIZE = 50 << 20  # 50 MiB
DEFAULT_MAX_TOTAL_SIZE = 100 << 20  # 100 MiB
|
||||||
|
|
||||||
|
|
||||||
class UploadResponse(BaseModel):
|
class UploadResponse(BaseModel):
|
||||||
"""Response model for file upload."""
|
"""Response model for file upload."""
|
||||||
@@ -39,6 +44,14 @@ class UploadResponse(BaseModel):
|
|||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class UploadLimits(BaseModel):
    """Application-level upload limits exposed to clients."""

    # Maximum number of files accepted in a single upload request.
    max_files: int
    # Maximum size of one uploaded file, in bytes.
    max_file_size: int
    # Maximum combined size of all files in one upload request, in bytes.
    max_total_size: int
|
||||||
|
|
||||||
|
|
||||||
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
|
||||||
"""Ensure uploaded files remain writable when mounted into non-local sandboxes.
|
"""Ensure uploaded files remain writable when mounted into non-local sandboxes.
|
||||||
|
|
||||||
@@ -69,6 +82,62 @@ def _get_uploads_config_value(app_config: AppConfig, key: str, default: object)
|
|||||||
return getattr(uploads_cfg, key, default)
|
return getattr(uploads_cfg, key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int:
    """Resolve one positive integer upload limit from the uploads config.

    ``key`` is looked up first. When it is unset and ``legacy_key`` is given,
    the legacy option name is consulted instead. When neither is set,
    ``default`` is used.

    Args:
        app_config: Application config handed to ``_get_uploads_config_value``.
        key: Preferred option name under ``uploads``.
        default: Value returned when nothing usable is configured.
        legacy_key: Optional older option name honoured when ``key`` is unset.

    Returns:
        The configured limit as a positive ``int``, or ``default`` when the
        configured value is a boolean, a fractional number, not convertible
        to ``int``, or not greater than zero.
    """
    # Track which option actually supplied the value so the warning below
    # points at the setting the operator has to fix.
    source_key = key
    try:
        value = _get_uploads_config_value(app_config, key, None)
        if value is None and legacy_key is not None:
            value = _get_uploads_config_value(app_config, legacy_key, None)
            if value is not None:
                source_key = legacy_key
        if value is None:
            value = default
        # ``bool`` is an ``int`` subclass, so ``int(True)`` would silently
        # turn ``max_files: true`` into a limit of 1.
        if isinstance(value, bool):
            raise ValueError("upload limit must not be a boolean")
        # ``int(1.9)`` truncates to 1; reject fractional limits instead of
        # quietly enforcing a different number than the one configured.
        if isinstance(value, float) and not value.is_integer():
            raise ValueError("upload limit must be a whole number")
        limit = int(value)
        if limit <= 0:
            raise ValueError("upload limit must be positive")
        return limit
    except Exception:
        # Deliberately broad: a malformed config entry must degrade to the
        # default limit rather than break the upload endpoints.
        logger.warning("Invalid uploads.%s value; falling back to %d", source_key, default)
        return default
|
||||||
|
|
||||||
|
|
||||||
|
def _get_upload_limits(app_config: AppConfig) -> UploadLimits:
    """Assemble the effective ``UploadLimits`` from the application config.

    The file-count and per-file-size limits also honour their legacy option
    names (``max_file_count`` and ``max_single_file_size``); the total-size
    limit is looked up under ``max_total_size`` only.
    """
    file_count_limit = _get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count")
    per_file_limit = _get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size")
    request_total_limit = _get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE)
    return UploadLimits(
        max_files=file_count_limit,
        max_file_size=per_file_limit,
        max_total_size=request_total_limit,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
|
||||||
|
for path in reversed(paths):
|
||||||
|
try:
|
||||||
|
os.unlink(path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def _write_upload_file_streaming(
    file: UploadFile,
    file_path: os.PathLike[str] | str,
    *,
    display_filename: str,
    max_single_file_size: int,
    max_total_size: int,
    total_size: int,
) -> tuple[int, int]:
    """Copy an upload to ``file_path`` chunk by chunk while enforcing size limits.

    Every chunk is counted against both limits before it is written, so the
    bytes written never exceed either limit.

    Args:
        file: Upload to read, ``UPLOAD_CHUNK_SIZE`` bytes per ``read`` call.
        file_path: Destination opened in binary write mode.
        display_filename: Name embedded in the "File too large" error detail.
        max_single_file_size: Maximum number of bytes allowed for this file.
        max_total_size: Maximum number of bytes allowed for the running total.
        total_size: Running total before this file is read.

    Returns:
        ``(bytes counted for this file, updated running total)``.

    Raises:
        HTTPException: Status 413 when this file or the running total exceeds
            its limit. The destination is opened before the first read, so a
            partially written file may remain; this function does not remove it.
    """
    bytes_for_file = 0
    running_total = total_size
    with open(file_path, "wb") as destination:
        while True:
            chunk = await file.read(UPLOAD_CHUNK_SIZE)
            if not chunk:
                break
            chunk_length = len(chunk)
            bytes_for_file += chunk_length
            running_total += chunk_length
            if bytes_for_file > max_single_file_size:
                raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
            if running_total > max_total_size:
                raise HTTPException(status_code=413, detail="Total upload size too large")
            destination.write(chunk)
    return bytes_for_file, running_total
|
||||||
|
|
||||||
|
|
||||||
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
|
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
|
||||||
"""Return whether automatic host-side document conversion is enabled.
|
"""Return whether automatic host-side document conversion is enabled.
|
||||||
|
|
||||||
@@ -96,12 +165,19 @@ async def upload_files(
|
|||||||
if not files:
|
if not files:
|
||||||
raise HTTPException(status_code=400, detail="No files provided")
|
raise HTTPException(status_code=400, detail="No files provided")
|
||||||
|
|
||||||
|
limits = _get_upload_limits(config)
|
||||||
|
if len(files) > limits.max_files:
|
||||||
|
raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
uploads_dir = ensure_uploads_dir(thread_id)
|
uploads_dir = ensure_uploads_dir(thread_id)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
|
||||||
uploaded_files = []
|
uploaded_files = []
|
||||||
|
written_paths = []
|
||||||
|
sandbox_sync_targets = []
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
sandbox_provider = get_sandbox_provider()
|
sandbox_provider = get_sandbox_provider()
|
||||||
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
|
||||||
@@ -109,6 +185,8 @@ async def upload_files(
|
|||||||
if sync_to_sandbox:
|
if sync_to_sandbox:
|
||||||
sandbox_id = sandbox_provider.acquire(thread_id)
|
sandbox_id = sandbox_provider.acquire(thread_id)
|
||||||
sandbox = sandbox_provider.get(sandbox_id)
|
sandbox = sandbox_provider.get(sandbox_id)
|
||||||
|
if sandbox is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
|
||||||
auto_convert_documents = _auto_convert_documents_enabled(config)
|
auto_convert_documents = _auto_convert_documents_enabled(config)
|
||||||
|
|
||||||
for file in files:
|
for file in files:
|
||||||
@@ -122,35 +200,41 @@ async def upload_files(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = await file.read()
|
|
||||||
file_path = uploads_dir / safe_filename
|
file_path = uploads_dir / safe_filename
|
||||||
file_path.write_bytes(content)
|
written_paths.append(file_path)
|
||||||
|
file_size, total_size = await _write_upload_file_streaming(
|
||||||
|
file,
|
||||||
|
file_path,
|
||||||
|
display_filename=safe_filename,
|
||||||
|
max_single_file_size=limits.max_file_size,
|
||||||
|
max_total_size=limits.max_total_size,
|
||||||
|
total_size=total_size,
|
||||||
|
)
|
||||||
|
|
||||||
virtual_path = upload_virtual_path(safe_filename)
|
virtual_path = upload_virtual_path(safe_filename)
|
||||||
|
|
||||||
if sync_to_sandbox and sandbox is not None:
|
if sync_to_sandbox:
|
||||||
_make_file_sandbox_writable(file_path)
|
sandbox_sync_targets.append((file_path, virtual_path))
|
||||||
sandbox.update_file(virtual_path, content)
|
|
||||||
|
|
||||||
file_info = {
|
file_info = {
|
||||||
"filename": safe_filename,
|
"filename": safe_filename,
|
||||||
"size": str(len(content)),
|
"size": str(file_size),
|
||||||
"path": str(sandbox_uploads / safe_filename),
|
"path": str(sandbox_uploads / safe_filename),
|
||||||
"virtual_path": virtual_path,
|
"virtual_path": virtual_path,
|
||||||
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
"artifact_url": upload_artifact_url(thread_id, safe_filename),
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
|
logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
|
||||||
|
|
||||||
file_ext = file_path.suffix.lower()
|
file_ext = file_path.suffix.lower()
|
||||||
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
|
if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
|
||||||
md_path = await convert_file_to_markdown(file_path)
|
md_path = await convert_file_to_markdown(file_path)
|
||||||
if md_path:
|
if md_path:
|
||||||
|
written_paths.append(md_path)
|
||||||
md_virtual_path = upload_virtual_path(md_path.name)
|
md_virtual_path = upload_virtual_path(md_path.name)
|
||||||
|
|
||||||
if sync_to_sandbox and sandbox is not None:
|
if sync_to_sandbox:
|
||||||
_make_file_sandbox_writable(md_path)
|
sandbox_sync_targets.append((md_path, md_virtual_path))
|
||||||
sandbox.update_file(md_virtual_path, md_path.read_bytes())
|
|
||||||
|
|
||||||
file_info["markdown_file"] = md_path.name
|
file_info["markdown_file"] = md_path.name
|
||||||
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
|
||||||
@@ -159,10 +243,19 @@ async def upload_files(
|
|||||||
|
|
||||||
uploaded_files.append(file_info)
|
uploaded_files.append(file_info)
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
_cleanup_uploaded_paths(written_paths)
|
||||||
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to upload {file.filename}: {e}")
|
logger.error(f"Failed to upload {file.filename}: {e}")
|
||||||
|
_cleanup_uploaded_paths(written_paths)
|
||||||
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
|
||||||
|
|
||||||
|
if sync_to_sandbox:
|
||||||
|
for file_path, virtual_path in sandbox_sync_targets:
|
||||||
|
_make_file_sandbox_writable(file_path)
|
||||||
|
sandbox.update_file(virtual_path, file_path.read_bytes())
|
||||||
|
|
||||||
return UploadResponse(
|
return UploadResponse(
|
||||||
success=True,
|
success=True,
|
||||||
files=uploaded_files,
|
files=uploaded_files,
|
||||||
@@ -170,6 +263,17 @@ async def upload_files(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/limits", response_model=UploadLimits)
@require_permission("threads", "read", owner_check=True)
async def get_upload_limits(
    thread_id: str,
    request: Request,
    config: AppConfig = Depends(get_config),
) -> UploadLimits:
    """Return upload limits used by the gateway for this thread."""
    # ``thread_id`` and ``request`` are not read in this body; presumably the
    # route and the permission decorator consume them — confirm before removing.
    effective_limits = _get_upload_limits(config)
    return effective_limits
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=dict)
|
@router.get("/list", response_model=dict)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ POST /api/threads/{thread_id}/uploads
|
|||||||
**请求体:** `multipart/form-data`
|
**请求体:** `multipart/form-data`
|
||||||
- `files`: 一个或多个文件
|
- `files`: 一个或多个文件
|
||||||
|
|
||||||
|
网关会在应用层限制上传规模,默认最多 10 个文件、单文件 50 MiB、单次请求总计 100 MiB。可通过 `config.yaml` 的 `uploads.max_files`、`uploads.max_file_size`、`uploads.max_total_size` 调整;前端会读取同一组限制并在选择文件时提示,超过限制时后端返回 `413 Payload Too Large`。
|
||||||
|
|
||||||
**响应:**
|
**响应:**
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -48,7 +50,23 @@ POST /api/threads/{thread_id}/uploads
|
|||||||
- `virtual_path`: Agent 在沙箱中使用的虚拟路径
|
- `virtual_path`: Agent 在沙箱中使用的虚拟路径
|
||||||
- `artifact_url`: 前端通过 HTTP 访问文件的 URL
|
- `artifact_url`: 前端通过 HTTP 访问文件的 URL
|
||||||
|
|
||||||
### 2. 列出已上传文件
|
### 2. 查询上传限制
|
||||||
|
```
|
||||||
|
GET /api/threads/{thread_id}/uploads/limits
|
||||||
|
```
|
||||||
|
|
||||||
|
返回网关当前生效的上传限制,供前端在用户选择文件前提示和拦截。
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"max_files": 10,
|
||||||
|
"max_file_size": 52428800,
|
||||||
|
"max_total_size": 104857600
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 列出已上传文件
|
||||||
```
|
```
|
||||||
GET /api/threads/{thread_id}/uploads/list
|
GET /api/threads/{thread_id}/uploads/list
|
||||||
```
|
```
|
||||||
@@ -71,7 +89,7 @@ GET /api/threads/{thread_id}/uploads/list
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. 删除文件
|
### 4. 删除文件
|
||||||
```
|
```
|
||||||
DELETE /api/threads/{thread_id}/uploads/{filename}
|
DELETE /api/threads/{thread_id}/uploads/{filename}
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -5,12 +5,35 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from _router_auth_helpers import call_unwrapped
|
import pytest
|
||||||
from fastapi import UploadFile
|
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||||
|
from fastapi import HTTPException, UploadFile
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from app.gateway.routers import uploads
|
from app.gateway.routers import uploads
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkedUpload:
|
||||||
|
def __init__(self, filename: str, chunks: list[bytes]):
|
||||||
|
self.filename = filename
|
||||||
|
self._chunks = list(chunks)
|
||||||
|
self.read_calls: list[int | None] = []
|
||||||
|
|
||||||
|
async def read(self, size: int | None = None) -> bytes:
|
||||||
|
self.read_calls.append(size)
|
||||||
|
if size is None:
|
||||||
|
raise AssertionError("upload must be read with an explicit chunk size")
|
||||||
|
if not self._chunks:
|
||||||
|
return b""
|
||||||
|
return self._chunks.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _mounted_provider() -> MagicMock:
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.uses_thread_data_mounts = True
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path):
|
def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path):
|
||||||
thread_uploads_dir = tmp_path / "uploads"
|
thread_uploads_dir = tmp_path / "uploads"
|
||||||
thread_uploads_dir.mkdir(parents=True)
|
thread_uploads_dir.mkdir(parents=True)
|
||||||
@@ -178,6 +201,173 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
|
|||||||
make_writable.assert_not_called()
|
make_writable.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path):
    """The sandbox is acquired while the uploads directory is still empty."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    remote_sandbox = MagicMock()
    remote_provider = MagicMock()
    remote_provider.uses_thread_data_mounts = False
    remote_provider.get.return_value = remote_sandbox

    def acquire_before_writes(thread_id: str) -> str:
        # Runs at acquire() time: nothing may have been written yet.
        assert list(uploads_root.iterdir()) == []
        return "aio-1"

    remote_provider.acquire.side_effect = acquire_before_writes

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=remote_provider),
    ):
        upload = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
        pending = call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[upload], config=SimpleNamespace())
        result = asyncio.run(pending)

    assert result.success is True
    remote_provider.acquire.assert_called_once_with("thread-aio")
    remote_sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/notes.txt", b"hello uploads")
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_fails_before_writing_when_non_local_sandbox_unavailable(tmp_path):
    """A failing sandbox acquisition aborts the upload before any read or write."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    failing_provider = MagicMock()
    failing_provider.uses_thread_data_mounts = False
    failing_provider.acquire.side_effect = RuntimeError("sandbox unavailable")
    upload = ChunkedUpload("notes.txt", [b"hello uploads"])

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=failing_provider),
        pytest.raises(RuntimeError, match="sandbox unavailable"),
    ):
        asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[upload], config=SimpleNamespace()))

    assert list(uploads_root.iterdir()) == []
    assert upload.read_calls == []
    failing_provider.get.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_rejects_too_many_files_before_writing(tmp_path):
    """Exceeding ``max_files`` yields 413 without reading or writing any file."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    single_file_limits = uploads.UploadLimits(max_files=1, max_file_size=10, max_total_size=20)
    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        patch.object(uploads, "_get_upload_limits", return_value=single_file_limits),
    ):
        files = [
            ChunkedUpload("one.txt", [b"one"]),
            ChunkedUpload("two.txt", [b"two"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    assert list(uploads_root.iterdir()) == []
    for upload in files:
        assert upload.read_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_rejects_oversized_single_file_and_removes_partial_file(tmp_path):
    """A file over ``max_file_size`` yields 413 and leaves no file behind."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    mounted = _mounted_provider()
    oversized = ChunkedUpload("big.txt", [b"123456"])
    tight_file_limits = uploads.UploadLimits(max_files=10, max_file_size=5, max_total_size=20)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=mounted),
        patch.object(uploads, "_get_upload_limits", return_value=tight_file_limits),
    ):
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[oversized], config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    assert not (uploads_root / "big.txt").exists()
    # Exactly one chunked read happened before the limit tripped.
    assert oversized.read_calls == [8192]
    mounted.acquire.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_rejects_total_size_over_limit_and_cleans_request_files(tmp_path):
    """Crossing ``max_total_size`` yields 413 and removes every file from the request."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    tight_total_limits = uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)
    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        patch.object(uploads, "_get_upload_limits", return_value=tight_total_limits),
    ):
        files = [
            ChunkedUpload("first.txt", [b"123"]),
            ChunkedUpload("second.txt", [b"456"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    for leftover in ("first.txt", "second.txt"):
        assert not (uploads_root / leftover).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_does_not_sync_non_local_sandbox_when_total_size_exceeds_limit(tmp_path):
    """A request rejected for total size must not push any file into the sandbox."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    remote_sandbox = MagicMock()
    remote_provider = MagicMock()
    remote_provider.uses_thread_data_mounts = False
    remote_provider.acquire.return_value = "aio-1"
    remote_provider.get.return_value = remote_sandbox

    tight_total_limits = uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)
    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=remote_provider),
        patch.object(uploads, "_get_upload_limits", return_value=tight_total_limits),
    ):
        files = [
            ChunkedUpload("first.txt", [b"123"]),
            ChunkedUpload("second.txt", [b"456"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=files, config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    remote_provider.acquire.assert_called_once_with("thread-aio")
    remote_provider.get.assert_called_once_with("aio-1")
    remote_sandbox.update_file.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_does_not_sync_non_local_sandbox_when_conversion_fails(tmp_path):
    """A failed document conversion yields 500, no sandbox sync, and no leftover file."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    remote_sandbox = MagicMock()
    remote_provider = MagicMock()
    remote_provider.uses_thread_data_mounts = False
    remote_provider.acquire.return_value = "aio-1"
    remote_provider.get.return_value = remote_sandbox

    failing_converter = AsyncMock(side_effect=RuntimeError("conversion failed"))
    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=remote_provider),
        patch.object(uploads, "_auto_convert_documents_enabled", return_value=True),
        patch.object(uploads, "convert_file_to_markdown", failing_converter),
    ):
        upload = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[upload], config=SimpleNamespace()))

    assert exc_info.value.status_code == 500
    remote_provider.acquire.assert_called_once_with("thread-aio")
    remote_provider.get.assert_called_once_with("aio-1")
    remote_sandbox.update_file.assert_not_called()
    assert not (uploads_root / "report.pdf").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_make_file_sandbox_writable_adds_write_bits_for_regular_files(tmp_path):
|
def test_make_file_sandbox_writable_adds_write_bits_for_regular_files(tmp_path):
|
||||||
file_path = tmp_path / "report.pdf"
|
file_path = tmp_path / "report.pdf"
|
||||||
file_path.write_bytes(b"pdf-bytes")
|
file_path.write_bytes(b"pdf-bytes")
|
||||||
@@ -286,3 +476,65 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
|
|||||||
assert uploads._auto_convert_documents_enabled(true_cfg) is True
|
assert uploads._auto_convert_documents_enabled(true_cfg) is True
|
||||||
assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
|
assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
|
||||||
assert uploads._auto_convert_documents_enabled(string_false_cfg) is False
|
assert uploads._auto_convert_documents_enabled(string_false_cfg) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_limits_endpoint_reads_uploads_config():
    """The limits endpoint reports the configured values, coercing numeric strings."""
    configured = {
        "max_files": 15,
        "max_file_size": "1048576",
        "max_total_size": 2097152,
    }
    cfg = MagicMock()
    cfg.uploads = configured

    pending = call_unwrapped(uploads.get_upload_limits, "thread-local", request=MagicMock(), config=cfg)
    result = asyncio.run(pending)

    assert result.max_files == 15
    assert result.max_file_size == 1048576
    assert result.max_total_size == 2097152
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_limits_endpoint_requires_thread_access():
    """With the owner check failing, the limits endpoint answers 404."""
    empty_cfg = MagicMock()
    empty_cfg.uploads = {}

    denied_app = make_authed_test_app(owner_check_passes=False)
    denied_app.state.config = empty_cfg
    denied_app.include_router(uploads.router)

    with TestClient(denied_app) as client:
        response = client.get("/api/threads/thread-local/uploads/limits")

    assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_limits_accept_legacy_config_keys():
    """Legacy option names are honoured when the current names are absent."""
    legacy_settings = {
        "max_file_count": 7,
        "max_single_file_size": 123,
        "max_total_size": 456,
    }
    cfg = MagicMock()
    cfg.uploads = legacy_settings

    expected = uploads.UploadLimits(max_files=7, max_file_size=123, max_total_size=456)
    assert uploads._get_upload_limits(cfg) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_files_uses_configured_file_count_limit(tmp_path):
    """``uploads.max_files`` from the real config path caps the number of files."""
    uploads_root = tmp_path / "uploads"
    uploads_root.mkdir(parents=True)

    cfg = MagicMock()
    cfg.uploads = {"max_files": 1}

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_root),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
    ):
        files = [
            ChunkedUpload("one.txt", [b"one"]),
            ChunkedUpload("two.txt", [b"two"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=cfg))

    assert exc_info.value.status_code == 413
|
||||||
|
|||||||
@@ -501,6 +501,11 @@ tool_search:
|
|||||||
# Option 1: Local Sandbox (Default)
|
# Option 1: Local Sandbox (Default)
|
||||||
# Executes commands directly on the host machine
|
# Executes commands directly on the host machine
|
||||||
uploads:
|
uploads:
|
||||||
|
# Application-level upload limits enforced by the gateway and exposed to the
|
||||||
|
# frontend before file selection.
|
||||||
|
max_files: 10
|
||||||
|
max_file_size: 52428800 # 50 MiB
|
||||||
|
max_total_size: 104857600 # 100 MiB
|
||||||
# Automatic Office/PDF conversion runs on the backend host before sandbox
|
# Automatic Office/PDF conversion runs on the backend host before sandbox
|
||||||
# isolation applies. Keep this disabled unless uploads come from a fully
|
# isolation applies. Keep this disabled unless uploads come from a fully
|
||||||
# trusted source and you intentionally accept host-side parser risk.
|
# trusted source and you intentionally accept host-side parser risk.
|
||||||
|
|||||||
Reference in New Issue
Block a user